Merge branch 'feat/end-user-oauth' into deploy/end-user-oauth

@@ -0,0 +1,6 @@
# Cursor Rules for Dify Project

## Automated Test Generation

- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
- When proposing or saving tests, re-read that document and follow every requirement.
@ -0,0 +1,226 @@
|
|||
# CODEOWNERS
|
||||
# This file defines code ownership for the Dify project.
|
||||
# Each line is a file pattern followed by one or more owners.
|
||||
# Owners can be @username, @org/team-name, or email addresses.
|
||||
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
api/ @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||
api/core/workflow/nodes/agent/ @Nov1c444
|
||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
api/core/workflow/nodes/loop/ @Nov1c444
|
||||
api/core/workflow/nodes/llm/ @Nov1c444
|
||||
|
||||
# Backend - RAG (Retrieval Augmented Generation)
|
||||
api/core/rag/ @JohnJyong
|
||||
api/services/rag_pipeline/ @JohnJyong
|
||||
api/services/dataset_service.py @JohnJyong
|
||||
api/services/knowledge_service.py @JohnJyong
|
||||
api/services/external_knowledge_service.py @JohnJyong
|
||||
api/services/hit_testing_service.py @JohnJyong
|
||||
api/services/metadata_service.py @JohnJyong
|
||||
api/services/vector_service.py @JohnJyong
|
||||
api/services/entities/knowledge_entities/ @JohnJyong
|
||||
api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||
api/controllers/console/datasets/ @JohnJyong
|
||||
api/controllers/service_api/dataset/ @JohnJyong
|
||||
api/models/dataset.py @JohnJyong
|
||||
api/tasks/rag_pipeline/ @JohnJyong
|
||||
api/tasks/add_document_to_index_task.py @JohnJyong
|
||||
api/tasks/batch_clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_notion_document_task.py @JohnJyong
|
||||
api/tasks/document_indexing_task.py @JohnJyong
|
||||
api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||
api/tasks/document_indexing_update_task.py @JohnJyong
|
||||
api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||
api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||
api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||
api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||
api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||
api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||
api/tasks/clean_dataset_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
|
||||
# Backend - Plugins
|
||||
api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
api/core/trigger/ @Mairuis @Yeuoly
|
||||
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
api/services/trigger/ @Mairuis @Yeuoly
|
||||
api/models/trigger.py @Mairuis @Yeuoly
|
||||
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Async Workflow
|
||||
api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Billing
|
||||
api/services/billing_service.py @hj24 @zyssyz123
|
||||
api/controllers/console/billing/ @hj24 @zyssyz123
|
||||
|
||||
# Backend - Enterprise
|
||||
api/configs/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/feature_service.py @GarfieldDai @GareArc
|
||||
api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
|
||||
# Backend - Database Migrations
|
||||
api/migrations/ @snakevash @laipz8200
|
||||
|
||||
# Frontend
|
||||
web/ @iamjoel
|
||||
|
||||
# Frontend - App - Orchestration
|
||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Chat
|
||||
web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Completion
|
||||
web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - App - List and Creation
|
||||
web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Monitoring
|
||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Settings
|
||||
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - Hit Testing
|
||||
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - List and Creation
|
||||
web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
||||
web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - RAG - Documents List
|
||||
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Segments List
|
||||
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Settings
|
||||
web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - Ecosystem - Plugins
|
||||
web/app/components/plugins/ @iamjoel @zhsama
|
||||
|
||||
# Frontend - Ecosystem - Tools
|
||||
web/app/components/tools/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Ecosystem - MarketPlace
|
||||
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Login and Registration
|
||||
web/app/signin/ @douxc @iamjoel
|
||||
web/app/signup/ @douxc @iamjoel
|
||||
web/app/reset-password/ @douxc @iamjoel
|
||||
web/app/install/ @douxc @iamjoel
|
||||
web/app/init/ @douxc @iamjoel
|
||||
web/app/forgot-password/ @douxc @iamjoel
|
||||
web/app/account/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Service Authentication
|
||||
web/service/base.ts @douxc @iamjoel
|
||||
|
||||
# Frontend - WebApp Authentication and Access Control
|
||||
web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||
web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Explore Page
|
||||
web/app/components/explore/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Personal Settings
|
||||
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Analytics
|
||||
web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Base Components
|
||||
web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
web/utils/time.ts @iamjoel @zxhlyh
|
||||
web/utils/format.ts @iamjoel @zxhlyh
|
||||
web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Billing and Education
|
||||
web/app/components/billing/ @iamjoel @zxhlyh
|
||||
web/app/education-apply/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Workspace
|
||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
|
|
@@ -0,0 +1,12 @@
# Copilot Instructions

GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.

Key reminders:

- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, mock cleanup).
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
- Target >95% line and branch coverage and 100% function/statement coverage.
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.

Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
@@ -77,12 +77,15 @@ jobs:
        uses: peter-evans/create-pull-request@v6
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          commit-message: Update i18n files and type definitions based on en-US changes
          title: 'chore: translate i18n files and update type definitions'
          commit-message: 'chore(i18n): update translations based on en-US changes'
          title: 'chore(i18n): translate i18n files and update type definitions'
          body: |
            This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.

            **Triggered by:** ${{ github.sha }}

            **Changes included:**
            - Updated translation files for all locales
            - Regenerated TypeScript type definitions for type safety
          branch: chore/automated-i18n-updates
          branch: chore/automated-i18n-updates-${{ github.sha }}
          delete-branch: true
@@ -0,0 +1,5 @@
# Windsurf Testing Rules

- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
- Honor every requirement in that document when generating or accepting tests.
- When proposing or saving tests, re-read that document and follow every requirement.
@@ -77,6 +77,8 @@ How we prioritize:

For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.

**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.

#### Backend

For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly.
@@ -36,6 +36,12 @@
    <img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
    <a href="https://github.com/langgenius/dify/discussions/" target="_blank">
    <img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
    <a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
    <img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
    <a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
    <img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
    <a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
    <img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
</p>

<p align="center">
@@ -176,6 +176,7 @@ WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word

# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1

@@ -539,6 +540,7 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100

# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0

# Celery beat configuration
@@ -16,6 +16,7 @@ layers =
    graph
    nodes
    node_events
    runtime
    entities
containers =
    core.workflow
@ -48,6 +48,12 @@ ENV PYTHONIOENCODING=utf-8
|
|||
|
||||
WORKDIR /app/api
|
||||
|
||||
# Create non-root user
|
||||
ARG dify_uid=1001
|
||||
RUN groupadd -r -g ${dify_uid} dify && \
|
||||
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
|
||||
chown -R dify:dify /app
|
||||
|
||||
RUN \
|
||||
apt-get update \
|
||||
# Install dependencies
|
||||
|
|
@ -57,7 +63,7 @@ RUN \
|
|||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
expat libldap-2.5-0=2.5.13+dfsg-5 perl libsqlite3-0=3.40.1-2+deb12u2 zlib1g=1:1.2.13.dfsg-1 \
|
||||
# install fonts to support the use of tools like pypdfium2
|
||||
fonts-noto-cjk \
|
||||
# install a package to improve the accuracy of guessing mime type and file extension
|
||||
|
|
@ -69,24 +75,29 @@ RUN \
|
|||
|
||||
# Copy Python environment and packages
|
||||
ENV VIRTUAL_ENV=/app/api/.venv
|
||||
COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
|
||||
# Download nltk data
|
||||
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
|
||||
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
|
||||
&& chmod -R 755 /usr/local/share/nltk_data
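The updated NLTK step above pre-downloads the corpora (including the newly added `stopwords`) into a shared `/usr/local/share/nltk_data` directory and makes it world-readable, so the non-root `dify` user can use it at runtime via the `NLTK_DATA` environment variable. A minimal sketch, assuming the `nltk` package is installed, of how code inside the container can confirm the pre-baked data is visible:

```python
import os

import nltk

# The image sets NLTK_DATA=/usr/local/share/nltk_data; nltk also reads this
# variable on import and adds it to its data search path.
nltk_data_dir = os.environ.get("NLTK_DATA", "/usr/local/share/nltk_data")
if nltk_data_dir not in nltk.data.path:
    nltk.data.path.append(nltk_data_dir)

# nltk.data.find raises LookupError if a corpus was not baked into the image.
print(nltk.data.find("tokenizers/punkt"))
print(nltk.data.find("corpora/stopwords"))
```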
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
|
||||
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
|
||||
|
||||
# Copy source code
|
||||
COPY . /app/api/
|
||||
COPY --chown=dify:dify . /app/api/
|
||||
|
||||
# Prepare entrypoint script
|
||||
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
# Copy entrypoint
|
||||
COPY docker/entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
ARG COMMIT_SHA
|
||||
ENV COMMIT_SHA=${COMMIT_SHA}
|
||||
ENV NLTK_DATA=/usr/local/share/nltk_data
|
||||
|
||||
USER dify
|
||||
|
||||
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
|
||||
|
|
|
|||
|
|
@ -73,6 +73,10 @@ class AppExecutionConfig(BaseSettings):
|
|||
description="Maximum allowed execution time for the application in seconds",
|
||||
default=1200,
|
||||
)
|
||||
APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description="Default number of concurrent active requests per app (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description="Maximum number of concurrent active requests per app (0 for unlimited)",
|
||||
default=0,
|
||||
|
|
|
|||
|
|
@ -31,3 +31,8 @@ class WeaviateConfig(BaseSettings):
|
|||
description="Number of objects to be processed in a single batch operation (default is 100)",
|
||||
default=100,
|
||||
)
|
||||
|
||||
WEAVIATE_TOKENIZATION: str | None = Field(
|
||||
description="Tokenization for Weaviate (default is word)",
|
||||
default="word",
|
||||
)
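Both new settings map one-to-one onto the environment variables added to `.env.example` (`WEAVIATE_BATCH_SIZE`, `WEAVIATE_TOKENIZATION`), because pydantic-settings resolves fields by name. A standalone sketch, assuming `pydantic-settings` is installed; `WeaviateDemoConfig` is a simplified stand-in for the `WeaviateConfig` class shown above:

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class WeaviateDemoConfig(BaseSettings):
    """Simplified stand-in for the WeaviateConfig fields in this diff."""

    WEAVIATE_BATCH_SIZE: int = Field(default=100, description="Objects per batch operation")
    WEAVIATE_TOKENIZATION: str | None = Field(default="word", description="Tokenization for Weaviate")


os.environ["WEAVIATE_TOKENIZATION"] = "gse"  # e.g. a CJK-friendly tokenizer
config = WeaviateDemoConfig()
print(config.WEAVIATE_BATCH_SIZE, config.WEAVIATE_TOKENIZATION)  # 100 gse
```

Unset variables fall back to the declared defaults, which is why `.env.example` lists `WEAVIATE_TOKENIZATION=word`.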
|
||||
|
|
|
|||
|
|
@ -1,16 +1,23 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
|
||||
class AdvancedPromptTemplateQuery(BaseModel):
|
||||
app_mode: str = Field(..., description="Application mode")
|
||||
model_mode: str = Field(..., description="Model mode")
|
||||
has_context: str = Field(default="true", description="Whether has context")
|
||||
model_name: str = Field(..., description="Model name")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AdvancedPromptTemplateQuery.__name__,
|
||||
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -18,7 +25,7 @@ parser = (
|
|||
class AdvancedPromptTemplateList(Resource):
|
||||
@console_ns.doc("get_advanced_prompt_templates")
|
||||
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
|
|
@ -27,6 +34,6 @@ class AdvancedPromptTemplateList(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = parser.parse_args()
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
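This file shows the migration pattern repeated throughout the backend changes in this PR: the `reqparse.RequestParser` argument list is replaced by a pydantic model, the model is registered with `console_ns.schema_model` so the Swagger docs still render, and the handler validates `request.args` (or `console_ns.payload` for JSON bodies) and hands `model_dump()` to the service layer. A standalone sketch of just the validation half, using only pydantic (the Flask request object and the `console_ns` registration from the diff are omitted):

```python
from pydantic import BaseModel, ConfigDict, Field, ValidationError


class AdvancedPromptTemplateQueryDemo(BaseModel):
    """Mirror of the query model above, usable without Flask."""

    # Silence pydantic's protected-namespace warning for the model_* field names.
    model_config = ConfigDict(protected_namespaces=())

    app_mode: str = Field(..., description="Application mode")
    model_mode: str = Field(..., description="Model mode")
    has_context: str = Field(default="true", description="Whether has context")
    model_name: str = Field(..., description="Model name")


# Equivalent of AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True))
raw_args = {"app_mode": "chat", "model_mode": "completion", "model_name": "demo-model"}
args = AdvancedPromptTemplateQueryDemo.model_validate(raw_args)
print(args.model_dump())  # defaults filled in; this dict is what the service layer receives

try:
    AdvancedPromptTemplateQueryDemo.model_validate({"app_mode": "chat"})
except ValidationError as exc:
    print(exc.error_count(), "validation errors")  # missing required query parameters
```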
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, abort
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -36,6 +39,130 @@ from services.enterprise.enterprise_service import EnterpriseService
|
|||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(value, list):
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
else:
|
||||
raise TypeError("Unsupported tag_ids type.")
|
||||
|
||||
if not items:
|
||||
return None
|
||||
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name to check")
|
||||
|
||||
|
||||
class AppIconPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon data")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
|
||||
class AppSiteStatusPayload(BaseModel):
|
||||
enable_site: bool = Field(..., description="Enable or disable site")
|
||||
|
||||
|
||||
class AppApiStatusPayload(BaseModel):
|
||||
enable_api: bool = Field(..., description="Enable or disable API")
|
||||
|
||||
|
||||
class AppTracePayload(BaseModel):
|
||||
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||
tracing_provider: str = Field(..., description="Tracing provider")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AppListQuery)
|
||||
reg(CreateAppPayload)
|
||||
reg(UpdateAppPayload)
|
||||
reg(CopyAppPayload)
|
||||
reg(AppExportQuery)
|
||||
reg(AppNamePayload)
|
||||
reg(AppIconPayload)
|
||||
reg(AppSiteStatusPayload)
|
||||
reg(AppApiStatusPayload)
|
||||
reg(AppTracePayload)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base models first
|
||||
|
|
@ -147,22 +274,7 @@ app_pagination_model = console_ns.model(
|
|||
class AppListApi(Resource):
|
||||
@console_ns.doc("list_apps")
|
||||
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
|
||||
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
|
||||
default="all",
|
||||
help="App mode filter",
|
||||
)
|
||||
.add_argument("name", type=str, location="args", help="Filter by app name")
|
||||
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
|
||||
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppListQuery.__name__])
|
||||
@console_ns.response(200, "Success", app_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -172,42 +284,12 @@ class AppListApi(Resource):
|
|||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def uuid_list(value):
|
||||
try:
|
||||
return [str(uuid.UUID(v)) for v in value.split(",")]
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
choices=[
|
||||
"completion",
|
||||
"chat",
|
||||
"advanced-chat",
|
||||
"workflow",
|
||||
"agent-chat",
|
||||
"channel",
|
||||
"all",
|
||||
],
|
||||
default="all",
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument("name", type=str, location="args", required=False)
|
||||
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_dict = args.model_dump()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
|
|
@ -254,19 +336,7 @@ class AppListApi(Resource):
|
|||
|
||||
@console_ns.doc("create_app")
|
||||
@console_ns.doc(description="Create a new application")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
|
||||
@console_ns.response(201, "App created successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
|
|
@ -279,22 +349,10 @@ class AppListApi(Resource):
|
|||
def post(self):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if "mode" not in args or args["mode"] is None:
|
||||
raise BadRequest("mode is required")
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(current_tenant_id, args, current_user)
|
||||
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
|
@ -326,20 +384,7 @@ class AppApi(Resource):
|
|||
@console_ns.doc("update_app")
|
||||
@console_ns.doc(description="Update application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
"max_active_requests": fields.Integer(description="Maximum active requests"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
|
|
@ -351,28 +396,18 @@ class AppApi(Resource):
|
|||
@marshal_with(app_detail_with_site_model)
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
.add_argument("max_active_requests", type=int, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
args_dict: AppService.ArgsDict = {
|
||||
"name": args["name"],
|
||||
"description": args.get("description", ""),
|
||||
"icon_type": args.get("icon_type", ""),
|
||||
"icon": args.get("icon", ""),
|
||||
"icon_background": args.get("icon_background", ""),
|
||||
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
||||
"max_active_requests": args.get("max_active_requests", 0),
|
||||
"name": args.name,
|
||||
"description": args.description or "",
|
||||
"icon_type": args.icon_type or "",
|
||||
"icon": args.icon or "",
|
||||
"icon_background": args.icon_background or "",
|
||||
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||
"max_active_requests": args.max_active_requests or 0,
|
||||
}
|
||||
app_model = app_service.update_app(app_model, args_dict)
|
||||
|
||||
|
|
@ -401,18 +436,7 @@ class AppCopyApi(Resource):
|
|||
@console_ns.doc("copy_app")
|
||||
@console_ns.doc(description="Create a copy of an existing application")
|
||||
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CopyAppRequest",
|
||||
{
|
||||
"name": fields.String(description="Name for the copied app"),
|
||||
"description": fields.String(description="Description for the copied app"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -426,15 +450,7 @@ class AppCopyApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
|
|
@ -443,11 +459,11 @@ class AppCopyApi(Resource):
|
|||
account=current_user,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=yaml_content,
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -462,11 +478,7 @@ class AppExportApi(Resource):
|
|||
@console_ns.doc("export_app")
|
||||
@console_ns.doc(description="Export application configuration as DSL")
|
||||
@console_ns.doc(params={"app_id": "Application ID to export"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
|
||||
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"App exported successfully",
|
||||
|
|
@ -480,30 +492,23 @@ class AppExportApi(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
# Add include_secret params
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
.add_argument("workflow_id", type=str, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
||||
app_model=app_model,
|
||||
include_secret=args.include_secret,
|
||||
workflow_id=args.workflow_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
@console_ns.doc(description="Check if app name is available")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
|
||||
@console_ns.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -512,10 +517,10 @@ class AppNameApi(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args = parser.parse_args()
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args["name"])
|
||||
app_model = app_service.update_app_name(app_model, args.name)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -525,16 +530,7 @@ class AppIconApi(Resource):
|
|||
@console_ns.doc("update_app_icon")
|
||||
@console_ns.doc(description="Update application icon")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppIconRequest",
|
||||
{
|
||||
"icon": fields.String(required=True, description="Icon data"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
|
||||
@console_ns.response(200, "Icon updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -544,15 +540,10 @@ class AppIconApi(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
||||
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -562,11 +553,7 @@ class AppSiteStatus(Resource):
|
|||
@console_ns.doc("update_app_site_status")
|
||||
@console_ns.doc(description="Enable or disable app site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
|
||||
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -576,11 +563,10 @@ class AppSiteStatus(Resource):
|
|||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
||||
app_model = app_service.update_app_site_status(app_model, args.enable_site)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -590,11 +576,7 @@ class AppApiStatus(Resource):
|
|||
@console_ns.doc("update_app_api_status")
|
||||
@console_ns.doc(description="Enable or disable app API")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
|
||||
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -604,11 +586,10 @@ class AppApiStatus(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
||||
app_model = app_service.update_app_api_status(app_model, args.enable_api)
|
||||
|
||||
return app_model
|
||||
|
||||
|
|
@ -631,15 +612,7 @@ class AppTraceApi(Resource):
|
|||
@console_ns.doc("update_app_trace")
|
||||
@console_ns.doc(description="Update app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppTraceRequest",
|
||||
{
|
||||
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
|
||||
@console_ns.response(200, "Trace configuration updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -648,17 +621,12 @@ class AppTraceApi(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
# add app trace
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("enabled", type=bool, required=True, location="json")
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
enabled=args["enabled"],
|
||||
tracing_provider=args["tracing_provider"],
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -17,7 +19,6 @@ from controllers.console.app.error import (
|
|||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
|
|
@ -32,9 +33,45 @@ from libs.login import current_user, login_required
|
|||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(default="", description="Query text")
|
||||
|
||||
|
||||
class ChatMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(..., description="User query")
|
||||
conversation_id: str | None = Field(default=None, description="Conversation ID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message ID")
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionMessagePayload.__name__,
|
||||
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
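One detail in `BaseMessagePayload` above: `model_config` is a reserved attribute on pydantic v2 models (it holds the model's configuration), so the request field is declared as `model_config_data` with `alias="model_config"`, and the handlers later call `model_dump(exclude_none=True, by_alias=True)` so downstream code still receives a `model_config` key. A minimal standalone sketch of that round trip, assuming pydantic v2; `MessagePayloadDemo` is a cut-down stand-in:

```python
from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class MessagePayloadDemo(BaseModel):
    # protected_namespaces=() silences the warning about the model_* field name;
    # the alias lets clients keep sending "model_config" in the JSON body.
    model_config = ConfigDict(protected_namespaces=())

    inputs: dict[str, Any]
    model_config_data: dict[str, Any] = Field(..., alias="model_config")


payload = MessagePayloadDemo.model_validate({"inputs": {}, "model_config": {"provider": "demo"}})
print(payload.model_config_data)          # accessed internally by the new field name
print(payload.model_dump(by_alias=True))  # {'inputs': {}, 'model_config': {'provider': 'demo'}}
```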
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
|
|
@ -43,19 +80,7 @@ class CompletionMessageApi(Resource):
|
|||
@console_ns.doc("create_completion_message")
|
||||
@console_ns.doc(description="Generate completion message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CompletionMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(description="Query text", default=""),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||
@console_ns.response(200, "Completion generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App not found")
|
||||
|
|
@ -64,18 +89,10 @@ class CompletionMessageApi(Resource):
|
|||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, location="json", default="")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
|
|
@ -121,7 +138,13 @@ class CompletionMessageStopApi(Resource):
|
|||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -131,21 +154,7 @@ class ChatMessageApi(Resource):
|
|||
@console_ns.doc("create_chat_message")
|
||||
@console_ns.doc(description="Generate chat message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ChatMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
"parent_message_id": fields.String(description="Parent message ID"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
@console_ns.response(200, "Chat message generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App or conversation not found")
|
||||
|
|
@ -155,20 +164,10 @@ class ChatMessageApi(Resource):
|
|||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
|
|
@ -220,6 +219,12 @@ class ChatMessageStopApi(Resource):
|
|||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from typing import Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
|
@ -14,13 +16,54 @@ from extensions.ext_database import db
from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString, TimestampField
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class BaseConversationQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword")
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
default="all", description="Annotation status filter"
)
page: int = Field(default=1, ge=1, le=99999, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")

@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value


class CompletionConversationQuery(BaseConversationQuery):
pass


class ChatConversationQuery(BaseConversationQuery):
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction"
)


console_ns.schema_model(
CompletionConversationQuery.__name__,
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatConversationQuery.__name__,
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -283,22 +326,7 @@ class CompletionConversationApi(Resource):
@console_ns.doc("list_completion_conversations")
@console_ns.doc(description="Get completion conversations with pagination and filtering")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
)
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
@console_ns.response(200, "Success", conversation_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -309,32 +337,17 @@ class CompletionConversationApi(Resource):
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)

if args["keyword"]:
if args.keyword:
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args['keyword']}%"),
Message.answer.ilike(f"%{args['keyword']}%"),
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
)
)
@ -342,7 +355,7 @@ class CompletionConversationApi(Resource):
assert account.timezone is not None

try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -354,11 +367,11 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)

# FIXME, the type ignore in this file
if args["annotation_status"] == "annotated":
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
@ -367,7 +380,7 @@ class CompletionConversationApi(Resource):

query = query.order_by(Conversation.created_at.desc())

conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)

return conversations
@ -419,31 +432,7 @@ class ChatConversationApi(Resource):
@console_ns.doc("list_chat_conversations")
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
.add_argument(
"sort_by",
type=str,
location="args",
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
default="-updated_at",
help="Sort field and direction",
)
)
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -454,31 +443,7 @@ class ChatConversationApi(Resource):
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

subquery = (
db.session.query(
@ -490,8 +455,8 @@ class ChatConversationApi(Resource):

query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

if args["keyword"]:
keyword_filter = f"%{args['keyword']}%"
if args.keyword:
keyword_filter = f"%{args.keyword}%"
query = (
query.join(
Message,
@ -514,12 +479,12 @@ class ChatConversationApi(Resource):
assert account.timezone is not None

try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))

if start_datetime_utc:
match args["sort_by"]:
match args.sort_by:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
@ -527,35 +492,35 @@ class ChatConversationApi(Resource):

if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args["sort_by"]:
match args.sort_by:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)

if args["annotation_status"] == "annotated":
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)

if args["message_count_gte"] and args["message_count_gte"] >= 1:
if args.message_count_gte and args.message_count_gte >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
.having(func.count(Message.id) >= args.message_count_gte)
)

if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)

match args["sort_by"]:
match args.sort_by:
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
@ -567,7 +532,7 @@ class ChatConversationApi(Resource):
case _:
query = query.order_by(Conversation.created_at.desc())

conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)

return conversations
@ -1,4 +1,6 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -14,6 +16,18 @@ from libs.login import login_required
|
|||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ConversationVariablesQuery.__name__,
|
||||
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
|
@ -33,11 +47,7 @@ class ConversationVariablesApi(Resource):
|
|||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -45,18 +55,14 @@ class ConversationVariablesApi(Resource):
|
|||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.order_by(ConversationVariable.created_at)
|
||||
)
|
||||
if args["conversation_id"]:
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
|
||||
else:
|
||||
raise ValueError("conversation_id is required")
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
|
||||
|
||||
# NOTE: This is a temporary solution to avoid performance issues.
|
||||
page = 1
|
||||
@ -1,6 +1,8 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
|
|
@ -21,21 +23,54 @@ from libs.login import current_account_with_tenant, login_required
|
|||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
code_language: str = Field(default="javascript", description="Programming language for code generation")
|
||||
|
||||
|
||||
class RuleStructuredOutputPayload(BaseModel):
|
||||
instruction: str = Field(..., description="Structured output generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
node_id: str = Field(default="", description="Node ID for workflow context")
|
||||
current: str = Field(default="", description="Current instruction text")
|
||||
language: str = Field(default="javascript", description="Programming language (javascript/python)")
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
class RuleGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_config")
|
||||
@console_ns.doc(description="Generate rule configuration using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Rule configuration generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -43,21 +78,15 @@ class RuleGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -75,19 +104,7 @@ class RuleGenerateApi(Resource):
|
|||
class RuleCodeGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_code")
|
||||
@console_ns.doc(description="Generate code rules using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleCodeGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Code generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
"code_language": fields.String(
|
||||
default="javascript", description="Programming language for code generation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Code rules generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -95,22 +112,15 @@ class RuleCodeGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.code_language,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -128,15 +138,7 @@ class RuleCodeGenerateApi(Resource):
|
|||
class RuleStructuredOutputGenerateApi(Resource):
|
||||
@console_ns.doc("generate_structured_output")
|
||||
@console_ns.doc(description="Generate structured output rules using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"StructuredOutputGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Structured output generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
|
||||
@console_ns.response(200, "Structured output generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -144,19 +146,14 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
|
@ -174,20 +171,7 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
class InstructionGenerateApi(Resource):
|
||||
@console_ns.doc("generate_instruction")
|
||||
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InstructionGenerateRequest",
|
||||
{
|
||||
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
|
||||
"node_id": fields.String(description="Node ID for workflow context"),
|
||||
"current": fields.String(description="Current instruction text"),
|
||||
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
|
||||
"instruction": fields.String(required=True, description="Instruction for generation"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Instruction generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
|
|
@ -195,79 +179,69 @@ class InstructionGenerateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||
.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||
.add_argument("current", type=str, required=False, default="", location="json")
|
||||
.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args['flow_id']} not found"}, 400
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
node = [node for node in nodes if node["id"] == args["node_id"]]
|
||||
node = [node for node in nodes if node["id"] == args.node_id]
|
||||
if len(node) == 0:
|
||||
return {"error": f"node {args['node_id']} not found"}, 400
|
||||
return {"error": f"node {args.node_id} not found"}, 400
|
||||
node_type = node[0]["data"]["type"]
|
||||
match node_type:
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["language"],
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
)
|
||||
case _:
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
flow_id=args.flow_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
)
|
||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||
if args.node_id != "" and args.current != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
node_id=args["node_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
flow_id=args.flow_id,
|
||||
node_id=args.node_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
workflow_service=WorkflowService(),
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
|
|
@ -285,24 +259,15 @@ class InstructionGenerateApi(Resource):
|
|||
class InstructionGenerationTemplateApi(Resource):
|
||||
@console_ns.doc("get_instruction_template")
|
||||
@console_ns.doc(description="Get instruction generation template")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InstructionTemplateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Template instruction"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
|
||||
@console_ns.response(200, "Template retrieved successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
match args["type"]:
|
||||
args = InstructionTemplatePayload.model_validate(console_ns.payload)
|
||||
match args.type:
|
||||
case "prompt":
|
||||
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
|
||||
|
||||
|
|
@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
|
|||
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args['type']}")
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
@ -1,7 +1,9 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
|
@ -33,6 +35,67 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
|
|||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ChatMessagesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID")
|
||||
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
@field_validator("first_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
@field_validator("conversation_id", "first_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class FeedbackExportQuery(BaseModel):
|
||||
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
|
||||
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
|
||||
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
|
||||
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
|
||||
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
|
||||
|
||||
@field_validator("has_comment", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool | None:
|
||||
if isinstance(value, bool) or value is None:
|
||||
return value
|
||||
lowered = value.lower()
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("has_comment must be a boolean value")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(ChatMessagesQuery)
|
||||
reg(MessageFeedbackPayload)
|
||||
reg(FeedbackExportQuery)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
|
@ -157,12 +220,7 @@ class ChatMessageListApi(Resource):
|
|||
@console_ns.doc("list_chat_messages")
|
||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
|
||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@login_required
|
||||
|
|
@ -172,27 +230,21 @@ class ChatMessageListApi(Resource):
|
|||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
.add_argument("first_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args["first_id"]:
|
||||
if args.first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -207,7 +259,7 @@ class ChatMessageListApi(Resource):
|
|||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
|
|
@ -215,12 +267,12 @@ class ChatMessageListApi(Resource):
|
|||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args["limit"]:
|
||||
if len(history_messages) == args.limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
|
|
@ -238,7 +290,7 @@ class ChatMessageListApi(Resource):
|
|||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
|
|
@ -246,15 +298,7 @@ class MessageFeedbackApi(Resource):
|
|||
@console_ns.doc("create_message_feedback")
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MessageFeedbackRequest",
|
||||
{
|
||||
"message_id": fields.String(required=True, description="Message ID"),
|
||||
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
|
|
@ -265,14 +309,9 @@ class MessageFeedbackApi(Resource):
|
|||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
message_id = str(args.message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
|
|
@ -281,18 +320,21 @@ class MessageFeedbackApi(Resource):
|
|||
|
||||
feedback = message.admin_feedback
|
||||
|
||||
if not args["rating"] and feedback:
|
||||
if not args.rating and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args["rating"] and feedback:
|
||||
feedback.rating = args["rating"]
|
||||
elif not args["rating"] and not feedback:
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
rating_value = args.rating
|
||||
if rating_value is None:
|
||||
raise ValueError("rating is required to create feedback")
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=args["rating"],
|
||||
rating=rating_value,
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
|
|
@ -369,6 +411,46 @@ class MessageSuggestedQuestionApi(Resource):
|
|||
return {"data": questions}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
|
||||
class MessageFeedbackExportApi(Resource):
|
||||
@console_ns.doc("export_feedbacks")
|
||||
@console_ns.doc(description="Export user feedback data for Google Sheets")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
|
||||
@console_ns.response(200, "Feedback data exported successfully")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# Import the service function
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
try:
|
||||
export_data = FeedbackService.export_feedbacks(
|
||||
app_id=app_model.id,
|
||||
from_source=args.from_source,
|
||||
rating=args.rating,
|
||||
has_comment=args.has_comment,
|
||||
start_date=args.start_date,
|
||||
end_date=args.end_date,
|
||||
format_type=args.format,
|
||||
)
|
||||
|
||||
return export_data
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Parameter validation error in feedback export")
|
||||
return {"error": f"Parameter validation error: {str(e)}"}, 400
|
||||
except Exception as e:
|
||||
logger.exception("Error exporting feedback data")
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
|
||||
class MessageApi(Resource):
|
||||
@console_ns.doc("get_message")
|
||||
@ -1,8 +1,9 @@
|
|||
from decimal import Decimal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -10,21 +11,37 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString, convert_datetime_to_date
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def empty_string_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
StatisticTimeRangeQuery.__name__,
|
||||
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
class DailyMessageStatistic(Resource):
|
||||
@console_ns.doc("get_daily_message_statistics")
|
||||
@console_ns.doc(description="Get daily message statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily message statistics retrieved successfully",
|
||||
|
|
@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -57,7 +69,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -81,19 +93,12 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@console_ns.doc("get_daily_conversation_statistics")
|
||||
@console_ns.doc(description="Get daily conversation statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
|
|
@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -121,7 +126,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -149,7 +154,7 @@ class DailyTerminalsStatistic(Resource):
|
|||
@console_ns.doc("get_daily_terminals_statistics")
|
||||
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
|
|
@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -177,7 +182,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -206,7 +211,7 @@ class DailyTokenCostStatistic(Resource):
|
|||
@console_ns.doc("get_daily_token_cost_statistics")
|
||||
@console_ns.doc(description="Get daily token cost statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
|
|
@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -235,7 +240,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -266,7 +271,7 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
@console_ns.doc("get_average_session_interaction_statistics")
|
||||
@console_ns.doc(description="Get average session interaction statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
|
|
@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -302,7 +307,7 @@ FROM
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -342,7 +347,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
||||
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
|
|
@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -374,7 +379,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -408,7 +413,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
@console_ns.doc("get_average_response_time_statistics")
|
||||
@console_ns.doc(description="Get average response time statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
|
|
@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -436,7 +441,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -465,7 +470,7 @@ class TokensPerSecondStatistic(Resource):
|
|||
@console_ns.doc("get_tokens_per_second_statistics")
|
||||
@console_ns.doc(description="Get tokens per second statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
|
|
@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
|
@ -495,7 +500,7 @@ WHERE
|
|||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -1,10 +1,11 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import Any
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
|
@ -90,17 +92,121 @@ workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagi
|
|||
# Otherwise register it here
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
|
||||
simple_end_user_model = None
|
||||
try:
|
||||
simple_end_user_model = console_ns.models.get("SimpleEndUser")
|
||||
except (KeyError, AttributeError):
|
||||
except AttributeError:
|
||||
pass
|
||||
if simple_end_user_model is None:
|
||||
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
|
||||
workflow_run_node_execution_model = None
|
||||
try:
|
||||
workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
|
||||
except (KeyError, AttributeError):
|
||||
except AttributeError:
|
||||
pass
|
||||
if workflow_run_node_execution_model is None:
|
||||
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
|
||||
|
||||
|
||||
class SyncDraftWorkflowPayload(BaseModel):
|
||||
graph: dict[str, Any]
|
||||
features: dict[str, Any]
|
||||
hash: str | None = None
|
||||
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BaseWorkflowRunPayload(BaseModel):
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any] | None = None
|
||||
query: str = ""
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class IterationNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LoopNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any]
|
||||
|
||||
|
||||
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class ConvertToWorkflowPayload(BaseModel):
|
||||
name: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||
node_ids: list[str]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(SyncDraftWorkflowPayload)
|
||||
reg(AdvancedChatWorkflowRunPayload)
|
||||
reg(IterationNodeRunPayload)
|
||||
reg(LoopNodeRunPayload)
|
||||
reg(DraftWorkflowRunPayload)
|
||||
reg(DraftWorkflowNodeRunPayload)
|
||||
reg(PublishWorkflowPayload)
|
||||
reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
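A minimal sketch of how these registrations are consumed elsewhere in this module, assuming the reg()/schema_model helpers above and Pydantic v2; the route and response below are illustrative only, not part of the change:

# reg() publishes each model's JSON schema under its class name so @console_ns.expect can
# reference it, while the handler re-validates the raw payload with the same model.
@console_ns.route("/apps/<uuid:app_id>/example")  # hypothetical route, for illustration
class ExampleApi(Resource):
    @console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
    def post(self, app_id):
        payload = PublishWorkflowPayload.model_validate(console_ns.payload or {})
        return {"marked_name": payload.marked_name}, 200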
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
# at the controller level rather than in the workflow logic. This would improve separation
|
||||
# of concerns and make the code more maintainable.
|
||||
|
|
@ -152,18 +258,7 @@ class DraftWorkflowApi(Resource):
|
|||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@console_ns.doc("sync_draft_workflow")
|
||||
@console_ns.doc(description="Sync draft workflow configuration")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SyncDraftWorkflowRequest",
|
||||
{
|
||||
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
|
||||
"features": fields.Raw(required=True, description="Workflow features configuration"),
|
||||
"hash": fields.String(description="Workflow hash for validation"),
|
||||
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow synced successfully",
|
||||
|
|
@ -187,36 +282,23 @@ class DraftWorkflowApi(Resource):
|
|||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
payload_data: dict[str, Any] | None = None
|
||||
if "application/json" in content_type:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("hash", type=str, required=False, location="json")
|
||||
.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload_data = request.get_json(silent=True)
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
|
||||
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
||||
raise ValueError("graph or features is not a dict")
|
||||
|
||||
args = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
}
|
||||
payload_data = json.loads(request.data.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
args = args_model.model_dump()
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
|
|
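The rewritten handler accepts both application/json and text/plain bodies and funnels them through the same model. A rough sketch of the flow for a text/plain body (the body contents here are illustrative):

raw = b'{"graph": {"nodes": [], "edges": []}, "features": {}}'   # example text/plain body
payload_data = json.loads(raw.decode("utf-8"))                    # as in the text/plain branch above
args = SyncDraftWorkflowPayload.model_validate(payload_data).model_dump()
# environment_variables / conversation_variables fall back to [] via Field(default_factory=list)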
@ -252,17 +334,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_workflow")
|
||||
@console_ns.doc(description="Run draft workflow for advanced chat application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AdvancedChatWorkflowRunRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
"files": fields.List(fields.Raw, description="File uploads"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow run started successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
|
|
@ -277,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json", default="")
|
||||
.add_argument("files", type=list, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
|
|
@ -316,15 +380,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"IterationNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Iteration node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -338,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
|
|
@ -363,15 +418,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
@console_ns.doc("run_workflow_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WorkflowIterationNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow iteration node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -385,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
|
|
@ -410,15 +456,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
@console_ns.doc("run_advanced_chat_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"LoopNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Loop node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -432,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
|
|
@ -457,15 +494,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
@console_ns.doc("run_workflow_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WorkflowLoopNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Workflow loop node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -479,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
|
|
@ -504,15 +532,7 @@ class DraftWorkflowRunApi(Resource):
|
|||
@console_ns.doc("run_draft_workflow")
|
||||
@console_ns.doc(description="Run draft workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowRunRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"files": fields.List(fields.Raw, description="File uploads"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
|
||||
@console_ns.response(200, "Draft workflow run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
|
|
@ -525,12 +545,7 @@ class DraftWorkflowRunApi(Resource):
|
|||
Run draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
|
|
@ -582,14 +597,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
@console_ns.doc("run_draft_workflow_node")
|
||||
@console_ns.doc(description="Run draft workflow node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowNodeRunRequest",
|
||||
{
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
|
||||
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
|
|
@ -604,15 +612,10 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
Run draft workflow node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("query", type=str, required=False, location="json", default="")
|
||||
.add_argument("files", type=list, location="json", default=[])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
user_inputs = args.get("inputs")
|
||||
user_inputs = args_model.inputs
|
||||
if user_inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
|
|
@ -637,13 +640,6 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
return workflow_node_execution
|
||||
|
||||
|
||||
parser_publish = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
class PublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_published_workflow")
|
||||
|
|
@ -668,7 +664,7 @@ class PublishedWorkflowApi(Resource):
|
|||
# Return the workflow; if not found, return None
|
||||
return workflow
|
||||
|
||||
@console_ns.expect(parser_publish)
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -680,13 +676,7 @@ class PublishedWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_publish.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
|
|
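The manual length checks removed here are now enforced declaratively by the Field constraints on PublishWorkflowPayload; a quick sanity check of that behaviour under Pydantic v2:

from pydantic import ValidationError

try:
    PublishWorkflowPayload.model_validate({"marked_name": "x" * 21})
except ValidationError:
    # max_length=20 rejects the value, replacing the old
    # "Marked name cannot exceed 20 characters" guard.
    pass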
@ -735,9 +725,6 @@ class DefaultBlockConfigsApi(Resource):
|
|||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@console_ns.doc("get_default_block_config")
|
||||
|
|
@ -745,7 +732,7 @@ class DefaultBlockConfigApi(Resource):
|
|||
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@console_ns.response(200, "Default block configuration retrieved successfully")
|
||||
@console_ns.response(404, "Block type not found")
|
||||
@console_ns.expect(parser_block)
|
||||
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -755,14 +742,12 @@ class DefaultBlockConfigApi(Resource):
|
|||
"""
|
||||
Get default block config
|
||||
"""
|
||||
args = parser_block.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
filters = None
|
||||
if q:
|
||||
if args.q:
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
filters = json.loads(args.q)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
|
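As before, q is expected to carry a JSON-encoded filter object in the query string; only the parsing moved from reqparse into the Pydantic query model. For example (the filter key below is illustrative):

# e.g. GET ...?q={"type":"llm"}  (URL-encoded in practice)
args = DefaultBlockConfigQuery.model_validate({"q": '{"type": "llm"}'})
filters = json.loads(args.q) if args.q else None   # mirrors the handler above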
|
@ -771,18 +756,9 @@ class DefaultBlockConfigApi(Resource):
|
|||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_convert = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@console_ns.expect(parser_convert)
|
||||
@console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
|
||||
@console_ns.doc("convert_to_workflow")
|
||||
@console_ns.doc(description="Convert application to workflow mode")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -802,10 +778,8 @@ class ConvertToWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if request.data:
|
||||
args = parser_convert.parse_args()
|
||||
else:
|
||||
args = {}
|
||||
payload = console_ns.payload or {}
|
||||
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
|
||||
|
||||
# convert to workflow mode
|
||||
workflow_service = WorkflowService()
|
||||
|
|
@ -817,18 +791,9 @@ class ConvertToWorkflowApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_workflows = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(parser_workflows)
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@console_ns.doc("get_all_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -845,16 +810,15 @@ class PublishedAllWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_workflows.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
named_only = args.named_only
|
||||
|
||||
if user_id:
|
||||
if user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
user_id = cast(str, user_id)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
|
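request.args.to_dict(flat=True) yields only strings, so the query models rely on Pydantic v2's lax coercion to turn them into ints and bools; a sketch:

q = WorkflowListQuery.model_validate({"page": "2", "limit": "50", "named_only": "true"})
assert q.page == 2 and q.limit == 50 and q.named_only is True
# Out-of-range values such as page="0" fail the ge=1 constraint with a ValidationError.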
|
@ -880,15 +844,7 @@ class WorkflowByIdApi(Resource):
|
|||
@console_ns.doc("update_workflow_by_id")
|
||||
@console_ns.doc(description="Update workflow by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateWorkflowRequest",
|
||||
{
|
||||
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Workflow updated successfully", workflow_model)
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
|
|
@ -903,25 +859,14 @@ class WorkflowByIdApi(Resource):
|
|||
Update workflow attributes
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if args.get("marked_name") is not None:
|
||||
update_data["marked_name"] = args["marked_name"]
|
||||
if args.get("marked_comment") is not None:
|
||||
update_data["marked_comment"] = args["marked_comment"]
|
||||
if args.marked_name is not None:
|
||||
update_data["marked_name"] = args.marked_name
|
||||
if args.marked_comment is not None:
|
||||
update_data["marked_comment"] = args.marked_comment
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
|
@ -1034,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource):
|
|||
Poll for trigger events and execute the full workflow when an event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"node_id", type=str, required=True, location="json", nullable=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
|
||||
node_id = args.node_id
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
|
|
@ -1166,14 +1108,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
@console_ns.doc("draft_workflow_trigger_run_all")
|
||||
@console_ns.doc(description="Debug a full workflow run when the start node is a trigger")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowTriggerRunAllRequest",
|
||||
{
|
||||
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
|
||||
@console_ns.response(200, "Workflow executed successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
|
|
@ -1188,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"node_ids", type=list, required=True, location="json", nullable=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
node_ids = args["node_ids"]
|
||||
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
|
||||
node_ids = args.node_ids
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from datetime import datetime
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -14,6 +17,48 @@ from models import App
|
|||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowAppLogQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
|
||||
status: WorkflowExecutionStatus | None = Field(
|
||||
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
|
||||
)
|
||||
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
|
||||
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
|
||||
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
|
||||
created_by_account: str | None = Field(default=None, description="Filter by account")
|
||||
detail: bool = Field(default=False, description="Whether to return detailed logs")
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
|
||||
@field_validator("created_at__before", "created_at__after", mode="before")
|
||||
@classmethod
|
||||
def parse_datetime(cls, value: str | None) -> datetime | None:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
return isoparse(value) # type: ignore
|
||||
|
||||
@field_validator("detail", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
lowered = value.lower()
|
||||
if lowered in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("Invalid boolean value for detail")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
|
||||
|
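The mode="before" validators normalise raw query-string values before field parsing, and status strings are coerced into the WorkflowExecutionStatus enum (assuming it is a string enum with the values listed in the field description); for instance, with illustrative inputs:

q = WorkflowAppLogQuery.model_validate(
    {"created_at__after": "2024-01-01T00:00:00Z", "detail": "yes", "status": "failed"}
)
# isoparse() turns the ISO string into a datetime, parse_bool maps "yes" -> True,
# and unknown "detail" tokens raise ValueError ("Invalid boolean value for detail").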
|
@ -23,19 +68,7 @@ class WorkflowAppLogApi(Resource):
|
|||
@console_ns.doc("get_workflow_app_logs")
|
||||
@console_ns.doc(description="Get workflow application execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"keyword": "Search keyword for filtering logs",
|
||||
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
|
||||
"created_at__before": "Filter logs created before this timestamp",
|
||||
"created_at__after": "Filter logs created after this timestamp",
|
||||
"created_by_end_user_session_id": "Filter by end user session ID",
|
||||
"created_by_account": "Filter by account",
|
||||
"detail": "Whether to return detailed logs",
|
||||
"page": "Page number (1-99999)",
|
||||
"limit": "Number of items per page (1-100)",
|
||||
}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -46,44 +79,7 @@ class WorkflowAppLogApi(Resource):
|
|||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
.add_argument(
|
||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("detail", type=bool, location="args", required=False, default=False)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -92,70 +93,51 @@ workflow_run_node_execution_list_model = console_ns.model(
|
|||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
def _parse_workflow_run_list_args():
|
||||
"""
|
||||
Parse common arguments for workflow run list endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
class WorkflowRunListQuery(BaseModel):
|
||||
last_id: str | None = Field(default=None, description="Last run ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_workflow_run_count_args():
|
||||
"""
|
||||
Parse common arguments for workflow run count endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing status, time_range, and triggered_from filters
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@field_validator("last_id")
|
||||
@classmethod
|
||||
def validate_last_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class WorkflowRunCountQuery(BaseModel):
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
|
||||
@field_validator("time_range")
|
||||
@classmethod
|
||||
def validate_time_range(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return time_duration(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowRunCountQuery.__name__,
|
||||
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
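Because status and triggered_from are typed as Literal, unsupported values are rejected at validation time instead of being passed through to the query layer; a sketch:

ok = WorkflowRunCountQuery.model_validate({"status": "running", "time_range": "7d"})
# "cron" is not in Literal["debugging", "app-run"], so the next call raises ValidationError:
# WorkflowRunCountQuery.model_validate({"triggered_from": "cron"})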
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
|
|
@ -170,6 +152,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -180,12 +163,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args = _parse_workflow_run_list_args()
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -217,6 +201,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -226,12 +211,13 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
|||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -259,6 +245,7 @@ class WorkflowRunListApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -268,12 +255,13 @@ class WorkflowRunListApi(Resource):
|
|||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args = _parse_workflow_run_list_args()
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
@ -305,6 +293,7 @@ class WorkflowRunCountApi(Resource):
|
|||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -314,12 +303,13 @@ class WorkflowRunCountApi(Resource):
|
|||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -7,12 +8,31 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowStatisticQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowStatisticQuery.__name__,
|
||||
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
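Empty strings in the query string are normalised to None before parse_time_range sees them, preserving the old optional-argument behaviour; for example (dates illustrative):

q = WorkflowStatisticQuery.model_validate({"start": "", "end": "2024-06-30 23:59"})
assert q.start is None                     # blank_to_none() mapped "" -> None
# parse_time_range(q.start, q.end, account.timezone) then resolves the open-ended range,
# exactly as in the handlers below.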
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
|
|
@ -24,9 +44,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_runs_statistic")
|
||||
@console_ns.doc(description="Get workflow daily runs statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily runs statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -35,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -71,9 +84,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_terminals_statistic")
|
||||
@console_ns.doc(description="Get workflow daily terminals statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -82,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -118,9 +124,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_daily_token_cost_statistic")
|
||||
@console_ns.doc(description="Get workflow daily token cost statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -129,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
@ -165,9 +164,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
@console_ns.doc("get_workflow_average_app_interaction_statistic")
|
||||
@console_ns.doc(description="Get workflow average app interaction statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -176,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
|
|
@ -16,12 +15,35 @@ from models.enums import AppTriggerStatus
|
|||
from models.model import Account, App, AppMode
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
from .. import console_ns
|
||||
from ..app.wraps import get_app_model
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
class ParserEnable(BaseModel):
|
||||
trigger_id: str
|
||||
enable_trigger: bool
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
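Parser backs the GET query string (node_id has no default, so omitting it fails validation), while ParserEnable is checked both by flask-restx's expect(..., validate=True) against the registered schema and again by model_validate in the handler; the values below are illustrative:

Parser.model_validate({"node_id": "start-node-1"})            # ok
# Parser.model_validate({}) raises ValidationError: node_id is required.
ParserEnable.model_validate({"trigger_id": "trigger-123", "enable_trigger": False})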
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
@console_ns.expect(console_ns.models[Parser.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -29,10 +51,9 @@ class WebhookTriggerApi(Resource):
|
|||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
node_id = str(args["node_id"])
|
||||
node_id = args.node_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
|
|
@ -51,6 +72,7 @@ class WebhookTriggerApi(Resource):
|
|||
return webhook_trigger
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/triggers")
|
||||
class AppTriggersApi(Resource):
|
||||
"""App Triggers list API"""
|
||||
|
||||
|
|
@ -90,7 +112,9 @@ class AppTriggersApi(Resource):
|
|||
return {"data": triggers}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -99,17 +123,11 @@ class AppTriggerEnableApi(Resource):
|
|||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ParserEnable.model_validate(console_ns.payload)
|
||||
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
trigger_id = args["trigger_id"]
|
||||
|
||||
trigger_id = args.trigger_id
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
|
|
@ -124,7 +142,7 @@ class AppTriggerEnableApi(Resource):
|
|||
raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
|
||||
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
|
@ -137,8 +155,3 @@ class AppTriggerEnableApi(Resource):
|
|||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
|
||||
|
||||
console_ns.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
console_ns.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
console_ns.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from controllers.console.app.error import (
|
|||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
|
|
@ -31,6 +30,7 @@ from libs.login import current_user
|
|||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
from .. import console_ns
|
||||
|
|
@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
|
|||
class CompletionApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
|
|
@ -102,12 +102,18 @@ class CompletionApi(InstalledAppResource):
|
|||
class CompletionStopApi(InstalledAppResource):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -184,6 +190,12 @@ class ChatStopApi(InstalledAppResource):
|
|||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
@ -58,7 +58,7 @@ class VersionApi(Resource):
response = httpx.get(
    check_update_url,
    params={"current_version": args["current_version"]},
    timeout=httpx.Timeout(connect=3, read=10),
    timeout=httpx.Timeout(timeout=10.0, connect=3.0),
)
except Exception as error:
    logger.warning("Check update version error: %s.", str(error))
@ -1,8 +1,10 @@
from datetime import datetime
from typing import Literal

import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import Session

@ -42,20 +44,160 @@ from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

def _init_parser():
    parser = reqparse.RequestParser()
    if dify_config.EDITION == "CLOUD":
        parser.add_argument("invitation_code", type=str, location="json")
    parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
        "timezone", type=timezone, required=True, location="json"
    )
    return parser

class AccountInitPayload(BaseModel):
    interface_language: str
    timezone: str
    invitation_code: str | None = None

    @field_validator("interface_language")
    @classmethod
    def validate_language(cls, value: str) -> str:
        return supported_language(value)

    @field_validator("timezone")
    @classmethod
    def validate_timezone(cls, value: str) -> str:
        return timezone(value)


class AccountNamePayload(BaseModel):
|
||||
name: str = Field(min_length=3, max_length=30)
|
||||
|
||||
|
||||
class AccountAvatarPayload(BaseModel):
|
||||
avatar: str
|
||||
|
||||
|
||||
class AccountInterfaceLanguagePayload(BaseModel):
|
||||
interface_language: str
|
||||
|
||||
@field_validator("interface_language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
class AccountInterfaceThemePayload(BaseModel):
|
||||
interface_theme: Literal["light", "dark"]
|
||||
|
||||
|
||||
class AccountTimezonePayload(BaseModel):
|
||||
timezone: str
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str) -> str:
|
||||
return timezone(value)
|
||||
|
||||
|
||||
class AccountPasswordPayload(BaseModel):
|
||||
password: str | None = None
|
||||
new_password: str
|
||||
repeat_new_password: str
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_passwords_match(self) -> "AccountPasswordPayload":
|
||||
if self.new_password != self.repeat_new_password:
|
||||
raise RepeatPasswordNotMatchError()
|
||||
return self
|
||||
|
||||
|
||||
class AccountDeletePayload(BaseModel):
|
||||
token: str
|
||||
code: str
|
||||
|
||||
|
||||
class AccountDeletionFeedbackPayload(BaseModel):
|
||||
email: str
|
||||
feedback: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class EducationActivatePayload(BaseModel):
|
||||
token: str
|
||||
institution: str
|
||||
role: str
|
||||
|
||||
|
||||
class EducationAutocompleteQuery(BaseModel):
|
||||
keywords: str
|
||||
page: int = 0
|
||||
limit: int = 20
|
||||
|
||||
|
||||
class ChangeEmailSendPayload(BaseModel):
|
||||
email: str
|
||||
language: str | None = None
|
||||
phase: str | None = None
|
||||
token: str | None = None
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class ChangeEmailValidityPayload(BaseModel):
|
||||
email: str
|
||||
code: str
|
||||
token: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class ChangeEmailResetPayload(BaseModel):
|
||||
new_email: str
|
||||
token: str
|
||||
|
||||
@field_validator("new_email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class CheckEmailUniquePayload(BaseModel):
|
||||
email: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(AccountInitPayload)
reg(AccountNamePayload)
reg(AccountAvatarPayload)
reg(AccountInterfaceLanguagePayload)
reg(AccountInterfaceThemePayload)
reg(AccountTimezonePayload)
reg(AccountPasswordPayload)
reg(AccountDeletePayload)
reg(AccountDeletionFeedbackPayload)
reg(EducationActivatePayload)
reg(EducationAutocompleteQuery)
reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)


@console_ns.route("/account/init")
class AccountInitApi(Resource):
    @console_ns.expect(_init_parser())
    @console_ns.expect(console_ns.models[AccountInitPayload.__name__])
    @setup_required
    @login_required
    def post(self):
@ -64,17 +206,18 @@ class AccountInitApi(Resource):
|
|||
if account.status == "active":
|
||||
raise AccountAlreadyInitedError()
|
||||
|
||||
args = _init_parser().parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInitPayload.model_validate(payload)
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
if not args["invitation_code"]:
|
||||
if not args.invitation_code:
|
||||
raise ValueError("invitation_code is required")
|
||||
|
||||
# check invitation code
|
||||
invitation_code = (
|
||||
db.session.query(InvitationCode)
|
||||
.where(
|
||||
InvitationCode.code == args["invitation_code"],
|
||||
InvitationCode.code == args.invitation_code,
|
||||
InvitationCode.status == "unused",
|
||||
)
|
||||
.first()
|
||||
|
|
@ -88,8 +231,8 @@ class AccountInitApi(Resource):
|
|||
invitation_code.used_by_tenant_id = account.current_tenant_id
|
||||
invitation_code.used_by_account_id = account.id
|
||||
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_language = args.interface_language
|
||||
account.timezone = args.timezone
|
||||
account.interface_theme = "light"
|
||||
account.status = "active"
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
|
@ -110,137 +253,104 @@ class AccountProfileApi(Resource):
|
|||
return current_user
|
||||
|
||||
|
||||
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/name")
|
||||
class AccountNameApi(Resource):
|
||||
@console_ns.expect(parser_name)
|
||||
@console_ns.expect(console_ns.models[AccountNamePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_name.parse_args()
|
||||
|
||||
# Validate account name length
|
||||
if len(args["name"]) < 3 or len(args["name"]) > 30:
|
||||
raise ValueError("Account name must be between 3 and 30 characters.")
|
||||
|
||||
updated_account = AccountService.update_account(current_user, name=args["name"])
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(parser_avatar)
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_avatar.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountAvatarPayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
|
||||
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_interface = reqparse.RequestParser().add_argument(
|
||||
"interface_language", type=supported_language, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-language")
|
||||
class AccountInterfaceLanguageApi(Resource):
|
||||
@console_ns.expect(parser_interface)
|
||||
@console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_interface.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceLanguagePayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_theme = reqparse.RequestParser().add_argument(
|
||||
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
class AccountInterfaceThemeApi(Resource):
|
||||
@console_ns.expect(parser_theme)
|
||||
@console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_theme.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceThemePayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/timezone")
|
||||
class AccountTimezoneApi(Resource):
|
||||
@console_ns.expect(parser_timezone)
|
||||
@console_ns.expect(console_ns.models[AccountTimezonePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_timezone.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountTimezonePayload.model_validate(payload)
|
||||
|
||||
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
|
||||
if args["timezone"] not in pytz.all_timezones:
|
||||
raise ValueError("Invalid timezone string.")
|
||||
|
||||
updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
|
||||
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_pw = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("password", type=str, required=False, location="json")
|
||||
.add_argument("new_password", type=str, required=True, location="json")
|
||||
.add_argument("repeat_new_password", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/password")
|
||||
class AccountPasswordApi(Resource):
|
||||
@console_ns.expect(parser_pw)
|
||||
@console_ns.expect(console_ns.models[AccountPasswordPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_pw.parse_args()
|
||||
|
||||
if args["new_password"] != args["repeat_new_password"]:
|
||||
raise RepeatPasswordNotMatchError()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountPasswordPayload.model_validate(payload)
|
||||
|
||||
try:
|
||||
AccountService.update_account_password(current_user, args["password"], args["new_password"])
|
||||
AccountService.update_account_password(current_user, args.password, args.new_password)
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
|
|
@ -316,25 +426,19 @@ class AccountDeleteVerifyApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_delete = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/delete")
|
||||
class AccountDeleteApi(Resource):
|
||||
@console_ns.expect(parser_delete)
|
||||
@console_ns.expect(console_ns.models[AccountDeletePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_delete.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountDeletePayload.model_validate(payload)
|
||||
|
||||
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
|
||||
if not AccountService.verify_account_deletion_code(args.token, args.code):
|
||||
raise InvalidAccountDeletionCodeError()
|
||||
|
||||
AccountService.delete_account(account)
|
||||
|
|
@ -342,21 +446,15 @@ class AccountDeleteApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_feedback = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("feedback", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/delete/feedback")
|
||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
@console_ns.expect(parser_feedback)
|
||||
@console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__])
|
||||
@setup_required
|
||||
def post(self):
|
||||
args = parser_feedback.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountDeletionFeedbackPayload.model_validate(payload)
|
||||
|
||||
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
|
||||
BillingService.update_account_deletion_feedback(args.email, args.feedback)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
@ -379,14 +477,6 @@ class EducationVerifyApi(Resource):
|
|||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||
|
||||
|
||||
parser_edu = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("institution", type=str, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/education")
|
||||
class EducationApi(Resource):
|
||||
status_fields = {
|
||||
|
|
@ -396,7 +486,7 @@ class EducationApi(Resource):
|
|||
"allow_refresh": fields.Boolean,
|
||||
}
|
||||
|
||||
@console_ns.expect(parser_edu)
|
||||
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -405,9 +495,10 @@ class EducationApi(Resource):
|
|||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_edu.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = EducationActivatePayload.model_validate(payload)
|
||||
|
||||
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
|
||||
return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -425,14 +516,6 @@ class EducationApi(Resource):
|
|||
return res
|
||||
|
||||
|
||||
parser_autocomplete = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keywords", type=str, required=True, location="args")
|
||||
.add_argument("page", type=int, required=False, location="args", default=0)
|
||||
.add_argument("limit", type=int, required=False, location="args", default=20)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/education/autocomplete")
|
||||
class EducationAutoCompleteApi(Resource):
|
||||
data_fields = {
|
||||
|
|
@ -441,7 +524,7 @@ class EducationAutoCompleteApi(Resource):
|
|||
"has_next": fields.Boolean,
|
||||
}
|
||||
|
||||
@console_ns.expect(parser_autocomplete)
|
||||
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -449,46 +532,39 @@ class EducationAutoCompleteApi(Resource):
|
|||
@cloud_edition_billing_enabled
|
||||
@marshal_with(data_fields)
|
||||
def get(self):
|
||||
args = parser_autocomplete.parse_args()
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = EducationAutocompleteQuery.model_validate(payload)
|
||||
|
||||
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
|
||||
|
||||
|
||||
parser_change_email = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
.add_argument("phase", type=str, required=False, location="json")
|
||||
.add_argument("token", type=str, required=False, location="json")
|
||||
)
|
||||
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email")
|
||||
class ChangeEmailSendEmailApi(Resource):
|
||||
@console_ns.expect(parser_change_email)
|
||||
@console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__])
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_change_email.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailSendPayload.model_validate(payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
account = None
|
||||
user_email = args["email"]
|
||||
if args["phase"] is not None and args["phase"] == "new_email":
|
||||
if args["token"] is None:
|
||||
user_email = args.email
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args["token"])
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
|
@ -497,118 +573,103 @@ class ChangeEmailSendEmailApi(Resource):
|
|||
raise InvalidEmailError()
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
|
||||
token = AccountService.send_change_email_email(
|
||||
account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
|
||||
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_validity = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/validity")
|
||||
class ChangeEmailCheckApi(Resource):
|
||||
@console_ns.expect(parser_validity)
|
||||
@console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__])
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = parser_validity.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args.email
|
||||
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
|
||||
if is_change_email_error_rate_limit:
|
||||
raise EmailChangeLimitError()
|
||||
|
||||
token_data = AccountService.get_change_email_data(args["token"])
|
||||
token_data = AccountService.get_change_email_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_change_email_error_rate_limit(args["email"])
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_change_email_error_rate_limit(args.email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_change_email_token(args["token"])
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_change_email_token(
|
||||
user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(args["email"])
|
||||
AccountService.reset_change_email_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
parser_reset = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("new_email", type=email, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
class ChangeEmailResetApi(Resource):
|
||||
@console_ns.expect(parser_reset)
|
||||
@console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__])
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
args = parser_reset.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
|
||||
if AccountService.is_account_in_freeze(args["new_email"]):
|
||||
if AccountService.is_account_in_freeze(args.new_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if not AccountService.check_email_unique(args["new_email"]):
|
||||
if not AccountService.check_email_unique(args.new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args["token"])
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
AccountService.revoke_change_email_token(args["token"])
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email != old_email:
|
||||
raise AccountNotFound()
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=args["new_email"])
|
||||
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
email=args["new_email"],
|
||||
email=args.new_email,
|
||||
)
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
class CheckEmailUnique(Resource):
|
||||
@console_ns.expect(parser_check)
|
||||
@console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__])
|
||||
@setup_required
|
||||
def post(self):
|
||||
args = parser_check.parse_args()
|
||||
if AccountService.is_account_in_freeze(args["email"]):
|
||||
payload = console_ns.payload or {}
|
||||
args = CheckEmailUniquePayload.model_validate(payload)
|
||||
if AccountService.is_account_in_freeze(args.email):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(args["email"]):
|
||||
if not AccountService.check_email_unique(args.email):
|
||||
raise EmailAlreadyInUseError()
|
||||
return {"result": "success"}
|
||||
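The `reg()` helper these controller diffs introduce registers each pydantic model's JSON schema on the namespace so `@console_ns.expect(console_ns.models[<Name>])` can document the request body. A hedged sketch of why the `#/definitions/{model}` ref template is passed explicitly (the namespace and models below are illustrative, not from this changeset):

from flask_restx import Namespace
from pydantic import BaseModel

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
sketch_ns = Namespace("sketch")  # illustrative namespace, not the real console_ns


class InnerPart(BaseModel):
    value: int


class OuterPayload(BaseModel):
    inner: InnerPart


# Pydantic v2 emits "#/$defs/{model}" refs by default (OpenAPI 3 style); flask-restx
# serves a Swagger 2.0 spec, which looks nested models up under "#/definitions/",
# so the explicit ref_template keeps nested $ref values resolvable in the docs UI.
schema = OuterPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
sketch_ns.schema_model(OuterPayload.__name__, schema)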
@ -1,4 +1,8 @@
from flask_restx import Resource, fields, reqparse
from typing import Any

from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@ -7,21 +11,49 @@ from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class EndpointCreatePayload(BaseModel):
    plugin_unique_identifier: str
    settings: dict[str, Any]
    name: str = Field(min_length=1)


class EndpointIdPayload(BaseModel):
    endpoint_id: str


class EndpointUpdatePayload(EndpointIdPayload):
    settings: dict[str, Any]
    name: str = Field(min_length=1)


class EndpointListQuery(BaseModel):
    page: int = Field(ge=1)
    page_size: int = Field(gt=0)


class EndpointListForPluginQuery(EndpointListQuery):
    plugin_id: str


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)

@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@console_ns.doc("create_endpoint")
|
||||
@console_ns.doc(description="Create a new plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointCreateRequest",
|
||||
{
|
||||
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
|
||||
"settings": fields.Raw(required=True, description="Endpoint settings"),
|
||||
"name": fields.String(required=True, description="Endpoint name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
|
|
@ -35,26 +67,16 @@ class EndpointCreateApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True)
|
||||
.add_argument("settings", type=dict, required=True)
|
||||
.add_argument("name", type=str, required=True)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
plugin_unique_identifier = args["plugin_unique_identifier"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
settings=settings,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
|
|
@ -65,11 +87,7 @@ class EndpointCreateApi(Resource):
|
|||
class EndpointListApi(Resource):
|
||||
@console_ns.doc("list_endpoints")
|
||||
@console_ns.doc(description="List plugin endpoints with pagination")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointListQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -83,15 +101,10 @@ class EndpointListApi(Resource):
|
|||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
page = args.page
|
||||
page_size = args.page_size
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
@ -109,12 +122,7 @@ class EndpointListApi(Resource):
|
|||
class EndpointListForSinglePluginApi(Resource):
|
||||
@console_ns.doc("list_plugin_endpoints")
|
||||
@console_ns.doc(description="List endpoints for a specific plugin")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -128,17 +136,11 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
plugin_id = args["plugin_id"]
|
||||
page = args.page
|
||||
page_size = args.page_size
|
||||
plugin_id = args.plugin_id
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
@ -157,11 +159,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
class EndpointDeleteApi(Resource):
|
||||
@console_ns.doc("delete_endpoint")
|
||||
@console_ns.doc(description="Delete a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
|
|
@ -175,13 +173,12 @@ class EndpointDeleteApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -189,16 +186,7 @@ class EndpointDeleteApi(Resource):
|
|||
class EndpointUpdateApi(Resource):
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointUpdateRequest",
|
||||
{
|
||||
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
|
||||
"settings": fields.Raw(required=True, description="Updated settings"),
|
||||
"name": fields.String(required=True, description="Updated name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
|
|
@ -212,25 +200,15 @@ class EndpointUpdateApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("endpoint_id", type=str, required=True)
|
||||
.add_argument("settings", type=dict, required=True)
|
||||
.add_argument("name", type=str, required=True)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
settings=settings,
|
||||
endpoint_id=args.endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -239,11 +217,7 @@ class EndpointUpdateApi(Resource):
|
|||
class EndpointEnableApi(Resource):
|
||||
@console_ns.doc("enable_endpoint")
|
||||
@console_ns.doc(description="Enable a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint enabled successfully",
|
||||
|
|
@ -257,13 +231,12 @@ class EndpointEnableApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -271,11 +244,7 @@ class EndpointEnableApi(Resource):
|
|||
class EndpointDisableApi(Resource):
|
||||
@console_ns.doc("disable_endpoint")
|
||||
@console_ns.doc(description="Disable a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint disabled successfully",
|
||||
|
|
@ -289,11 +258,10 @@ class EndpointDisableApi(Resource):
|
|||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
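For GET endpoints the migration validates the query string rather than the JSON body, relying on pydantic to coerce the string values from `request.args`. A small sketch of that half of the pattern, assuming Flask 2.x and a paging model shaped like the `EndpointListQuery` above (the app and route names are illustrative):

from flask import Flask, request
from pydantic import BaseModel, Field

app = Flask(__name__)  # stand-in app; the real handlers live on console_ns resources


class PageQuery(BaseModel):
    page: int = Field(ge=1)
    page_size: int = Field(gt=0)


@app.get("/sketch/items")
def list_items():
    # request.args values are strings; pydantic coerces "page=2" to int and enforces
    # the ge/gt constraints, which reqparse previously handled via type=int arguments.
    args = PageQuery.model_validate(request.args.to_dict(flat=True))
    return {"page": args.page, "page_size": args.page_size}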
@ -1,7 +1,8 @@
|
|||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
|
|
@ -31,6 +32,42 @@ from services.account_service import AccountService, RegisterService, TenantServ
|
|||
from services.errors.account import AccountAlreadyInTenantError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class MemberInvitePayload(BaseModel):
|
||||
emails: list[str] = Field(default_factory=list)
|
||||
role: TenantAccountRole
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class MemberRoleUpdatePayload(BaseModel):
|
||||
role: str
|
||||
|
||||
|
||||
class OwnerTransferEmailPayload(BaseModel):
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class OwnerTransferCheckPayload(BaseModel):
|
||||
code: str
|
||||
token: str
|
||||
|
||||
|
||||
class OwnerTransferPayload(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(MemberInvitePayload)
|
||||
reg(MemberRoleUpdatePayload)
|
||||
reg(OwnerTransferEmailPayload)
|
||||
reg(OwnerTransferCheckPayload)
|
||||
reg(OwnerTransferPayload)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
class MemberListApi(Resource):
|
||||
|
|
@ -48,29 +85,22 @@ class MemberListApi(Resource):
|
|||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
||||
parser_invite = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("emails", type=list, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/invite-email")
|
||||
class MemberInviteEmailApi(Resource):
|
||||
"""Invite a new member by email."""
|
||||
|
||||
@console_ns.expect(parser_invite)
|
||||
@console_ns.expect(console_ns.models[MemberInvitePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def post(self):
|
||||
args = parser_invite.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberInvitePayload.model_validate(payload)
|
||||
|
||||
invitee_emails = args["emails"]
|
||||
invitee_role = args["role"]
|
||||
interface_language = args["language"]
|
||||
invitee_emails = args.emails
|
||||
invitee_role = args.role
|
||||
interface_language = args.language
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -146,20 +176,18 @@ class MemberCancelInviteApi(Resource):
|
|||
}, 200
|
||||
|
||||
|
||||
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
|
||||
class MemberUpdateRoleApi(Resource):
|
||||
"""Update member role."""
|
||||
|
||||
@console_ns.expect(parser_update)
|
||||
@console_ns.expect(console_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, member_id):
|
||||
args = parser_update.parse_args()
|
||||
new_role = args["role"]
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberRoleUpdatePayload.model_validate(payload)
|
||||
new_role = args.role
|
||||
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
|
|
@ -197,20 +225,18 @@ class DatasetOperatorMemberListApi(Resource):
|
|||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
||||
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
|
||||
class SendOwnerTransferEmailApi(Resource):
|
||||
"""Send owner transfer email."""
|
||||
|
||||
@console_ns.expect(parser_send)
|
||||
@console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
args = parser_send.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferEmailPayload.model_validate(payload)
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
|
@ -221,7 +247,7 @@ class SendOwnerTransferEmailApi(Resource):
|
|||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
raise NotOwnerError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
|
@ -238,22 +264,16 @@ class SendOwnerTransferEmailApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_owner = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/owner-transfer-check")
|
||||
class OwnerTransferCheckApi(Resource):
|
||||
@console_ns.expect(parser_owner)
|
||||
@console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
args = parser_owner.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferCheckPayload.model_validate(payload)
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
|
|
@ -267,41 +287,37 @@ class OwnerTransferCheckApi(Resource):
|
|||
if is_owner_transfer_error_rate_limit:
|
||||
raise OwnerTransferLimitError()
|
||||
|
||||
token_data = AccountService.get_owner_transfer_data(args["token"])
|
||||
token_data = AccountService.get_owner_transfer_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_owner_transfer_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_owner_transfer_token(args["token"])
|
||||
AccountService.revoke_owner_transfer_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
|
||||
_, new_token = AccountService.generate_owner_transfer_token(user_email, code=args.code, additional_data={})
|
||||
|
||||
AccountService.reset_owner_transfer_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
parser_owner_transfer = reqparse.RequestParser().add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
|
||||
class OwnerTransfer(Resource):
|
||||
@console_ns.expect(parser_owner_transfer)
|
||||
@console_ns.expect(console_ns.models[OwnerTransferPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self, member_id):
|
||||
args = parser_owner_transfer.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferPayload.model_validate(payload)
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -313,14 +329,14 @@ class OwnerTransfer(Resource):
|
|||
if current_user.id == str(member_id):
|
||||
raise CannotTransferOwnerToSelfError()
|
||||
|
||||
transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
|
||||
transfer_token_data = AccountService.get_owner_transfer_data(args.token)
|
||||
if not transfer_token_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if transfer_token_data.get("email") != current_user.email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
AccountService.revoke_owner_transfer_token(args["token"])
|
||||
AccountService.revoke_owner_transfer_token(args.token)
|
||||
|
||||
member = db.session.get(Account, str(member_id))
|
||||
if not member:
|
||||
@ -1,31 +1,97 @@
import io
from typing import Any, Literal

from flask import send_file
from flask_restx import Resource, reqparse
from flask import request, send_file
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService

parser_model = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ParserModelList(BaseModel):
|
||||
model_type: ModelType | None = None
|
||||
|
||||
|
||||
class ParserCredentialId(BaseModel):
|
||||
credential_id: str | None = None
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_optional_credential_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialCreate(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
|
||||
|
||||
class ParserCredentialUpdate(BaseModel):
|
||||
credential_id: str
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_update_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialDelete(BaseModel):
|
||||
credential_id: str
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_delete_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialSwitch(BaseModel):
|
||||
credential_id: str
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_switch_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialValidate(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class ParserPreferredProviderType(BaseModel):
|
||||
preferred_provider_type: Literal["system", "custom"]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(ParserModelList)
|
||||
reg(ParserCredentialId)
|
||||
reg(ParserCredentialCreate)
|
||||
reg(ParserCredentialUpdate)
|
||||
reg(ParserCredentialDelete)
|
||||
reg(ParserCredentialSwitch)
|
||||
reg(ParserCredentialValidate)
|
||||
reg(ParserPreferredProviderType)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
class ModelProviderListApi(Resource):
|
||||
@console_ns.expect(parser_model)
|
||||
@console_ns.expect(console_ns.models[ParserModelList.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -33,38 +99,18 @@ class ModelProviderListApi(Resource):
|
|||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
args = parser_model.parse_args()
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = ParserModelList.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
|
||||
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.model_type)
|
||||
|
||||
return jsonable_encoder({"data": provider_list})
|
||||
|
||||
|
||||
parser_cred = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
|
||||
)
|
||||
parser_post_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
parser_put_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
parser_delete_cred = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
class ModelProviderCredentialApi(Resource):
|
||||
@console_ns.expect(parser_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialId.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -72,23 +118,25 @@ class ModelProviderCredentialApi(Resource):
|
|||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
args = parser_cred.parse_args()
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = ParserCredentialId.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credential(
|
||||
tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
|
||||
tenant_id=tenant_id, provider=provider, credential_id=args.credential_id
|
||||
)
|
||||
|
||||
return {"credentials": credentials}
|
||||
|
||||
@console_ns.expect(parser_post_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialCreate.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_post_cred.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialCreate.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -96,15 +144,15 @@ class ModelProviderCredentialApi(Resource):
|
|||
model_provider_service.create_provider_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
credentials=args.credentials,
|
||||
credential_name=args.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@console_ns.expect(parser_put_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialUpdate.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -112,7 +160,8 @@ class ModelProviderCredentialApi(Resource):
|
|||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_put_cred.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialUpdate.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -120,71 +169,64 @@ class ModelProviderCredentialApi(Resource):
|
|||
model_provider_service.update_provider_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
credentials=args.credentials,
|
||||
credential_id=args.credential_id,
|
||||
credential_name=args.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(parser_delete_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialDelete.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_delete_cred.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialDelete.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_switch = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@console_ns.expect(parser_switch)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_switch.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialSwitch.model_validate(payload)
|
||||
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credential_id=args["credential_id"],
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
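Note (illustrative, not part of this diff): across these endpoints the migration follows one pattern: GET handlers validate the flattened query string, while POST/PUT/DELETE handlers validate the JSON body exposed by console_ns.payload. A minimal sketch of that split, using hypothetical ExampleQuery/ExampleBody models rather than classes from this PR:

# Sketch only: illustrates the request.args vs. console_ns.payload split used above.
# ExampleQuery and ExampleBody are stand-ins, not models defined in this diff.
from flask import request
from pydantic import BaseModel


class ExampleQuery(BaseModel):
    credential_id: str | None = None


class ExampleBody(BaseModel):
    credentials: dict
    name: str | None = None


def example_get():
    # Query-string values arrive as flat strings; Pydantic validates/coerces them.
    args = ExampleQuery.model_validate(request.args.to_dict(flat=True))
    return args.credential_id


def example_post(payload: dict | None):
    # JSON bodies are read from console_ns.payload, which may be None when empty.
    args = ExampleBody.model_validate(payload or {})
    return args.credentials, args.name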
||||
|
||||
parser_validate = reqparse.RequestParser().add_argument(
|
||||
"credentials", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@console_ns.expect(parser_validate)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialValidate.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_validate.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialValidate.model_validate(payload)
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
|
|
@ -195,7 +237,7 @@ class ModelProviderValidateApi(Resource):
|
|||
|
||||
try:
|
||||
model_provider_service.validate_provider_credentials(
|
||||
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
|
||||
tenant_id=tenant_id, provider=provider, credentials=args.credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
|
|
@ -228,19 +270,9 @@ class ModelProviderIconApi(Resource):
|
|||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||
|
||||
|
||||
parser_preferred = reqparse.RequestParser().add_argument(
|
||||
"preferred_provider_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=["system", "custom"],
|
||||
location="json",
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@console_ns.expect(parser_preferred)
|
||||
@console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -250,11 +282,12 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
args = parser_preferred.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserPreferredProviderType.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.switch_preferred_provider(
|
||||
tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
|
||||
tenant_id=tenant_id, provider=provider, preferred_provider_type=args.preferred_provider_type
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -1,52 +1,144 @@
|
|||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
parser_get_default = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
parser_post_default = reqparse.RequestParser().add_argument(
|
||||
"model_settings", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
class ParserGetDefault(BaseModel):
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class ParserPostDefault(BaseModel):
|
||||
class Inner(BaseModel):
|
||||
model_type: ModelType
|
||||
model: str | None = None
|
||||
provider: str | None = None
|
||||
|
||||
model_settings: list[Inner]
|
||||
|
||||
|
||||
class ParserDeleteModels(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class LoadBalancingPayload(BaseModel):
|
||||
configs: list[dict[str, Any]] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class ParserPostModels(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
load_balancing: LoadBalancingPayload | None = None
|
||||
config_from: str | None = None
|
||||
credential_id: str | None = None
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_credential_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
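Illustrative sketch (not part of this diff), assuming Pydantic v2 semantics: model_validate() parses the nested load_balancing dict from the JSON body into LoadBalancingPayload, and the plain "llm" string is coerced into ModelType, so a payload like the one below validates in a single call. The model name is hypothetical.

# Sketch only: example payload shape accepted by ParserPostModels.
example_payload = {
    "model": "gpt-4o",  # hypothetical model name
    "model_type": "llm",
    "load_balancing": {"configs": [{"name": "config-1"}], "enabled": True},
    "config_from": "predefined-model",
}
parsed = ParserPostModels.model_validate(example_payload)
assert parsed.load_balancing is not None and parsed.load_balancing.enabled is True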
||||
|
||||
class ParserGetCredentials(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
config_from: str | None = None
|
||||
credential_id: str | None = None
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_get_credential_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialBase(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class ParserCreateCredential(ParserCredentialBase):
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class ParserUpdateCredential(ParserCredentialBase):
|
||||
credential_id: str
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_update_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserDeleteCredential(ParserCredentialBase):
|
||||
credential_id: str
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_delete_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserParameter(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(ParserGetDefault)
|
||||
reg(ParserPostDefault)
|
||||
reg(ParserDeleteModels)
|
||||
reg(ParserPostModels)
|
||||
reg(ParserGetCredentials)
|
||||
reg(ParserCreateCredential)
|
||||
reg(ParserUpdateCredential)
|
||||
reg(ParserDeleteCredential)
|
||||
reg(ParserParameter)
|
||||
|
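A brief aside on the ref template (illustrative, not part of this diff): Pydantic v2 emits "#/$defs/{model}" references by default, which the Swagger 2.0 documents generated by flask-restx keep under "definitions", so reg() passes the explicit "#/definitions/{model}" template to keep nested $refs resolvable. A minimal sketch of the difference, using stand-in models:

# Sketch only: default Pydantic v2 refs vs. the Swagger 2.0 style used by reg().
from pydantic import BaseModel


class DemoInner(BaseModel):
    name: str


class DemoOuter(BaseModel):
    inner: DemoInner


default_schema = DemoOuter.model_json_schema()
swagger2_schema = DemoOuter.model_json_schema(ref_template="#/definitions/{model}")
print(default_schema["properties"]["inner"])   # typically {'$ref': '#/$defs/DemoInner'}
print(swagger2_schema["properties"]["inner"])  # {'$ref': '#/definitions/DemoInner'}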
||||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
class DefaultModelApi(Resource):
|
||||
@console_ns.expect(parser_get_default)
|
||||
@console_ns.expect(console_ns.models[ParserGetDefault.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_get_default.parse_args()
|
||||
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
tenant_id=tenant_id, model_type=args["model_type"]
|
||||
tenant_id=tenant_id, model_type=args.model_type
|
||||
)
|
||||
|
||||
return jsonable_encoder({"data": default_model_entity})
|
||||
|
||||
@console_ns.expect(parser_post_default)
|
||||
@console_ns.expect(console_ns.models[ParserPostDefault.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -54,66 +146,31 @@ class DefaultModelApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_post_default.parse_args()
|
||||
args = ParserPostDefault.model_validate(console_ns.payload)
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args["model_settings"]
|
||||
model_settings = args.model_settings
|
||||
for model_setting in model_settings:
|
||||
if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
|
||||
raise ValueError("invalid model type")
|
||||
|
||||
if "provider" not in model_setting:
|
||||
if model_setting.provider is None:
|
||||
continue
|
||||
|
||||
if "model" not in model_setting:
|
||||
raise ValueError("invalid model")
|
||||
|
||||
try:
|
||||
model_provider_service.update_default_model_of_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_setting["model_type"],
|
||||
provider=model_setting["provider"],
|
||||
model=model_setting["model"],
|
||||
model_type=model_setting.model_type,
|
||||
provider=model_setting.provider,
|
||||
model=cast(str, model_setting.model),
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.exception(
|
||||
"Failed to update default model, model type: %s, model: %s",
|
||||
model_setting["model_type"],
|
||||
model_setting.get("model"),
|
||||
model_setting.model_type,
|
||||
model_setting.model,
|
||||
)
|
||||
raise ex
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_post_models = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser_delete_models = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
|
||||
class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -127,7 +184,7 @@ class ModelProviderModelApi(Resource):
|
|||
|
||||
return jsonable_encoder({"data": models})
|
||||
|
||||
@console_ns.expect(parser_post_models)
|
||||
@console_ns.expect(console_ns.models[ParserPostModels.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -135,45 +192,45 @@ class ModelProviderModelApi(Resource):
|
|||
def post(self, provider: str):
|
||||
# Save the model's load-balancing configs
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = parser_post_models.parse_args()
|
||||
args = ParserPostModels.model_validate(console_ns.payload)
|
||||
|
||||
if args.get("config_from", "") == "custom-model":
|
||||
if not args.get("credential_id"):
|
||||
if args.config_from == "custom-model":
|
||||
if not args.credential_id:
|
||||
raise ValueError("credential_id is required when configuring a custom-model")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_custom_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
||||
if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
|
||||
if args.load_balancing and args.load_balancing.configs:
|
||||
# save load balancing configs
|
||||
model_load_balancing_service.update_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
configs=args["load_balancing"]["configs"],
|
||||
config_from=args.get("config_from", ""),
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
configs=args.load_balancing.configs,
|
||||
config_from=args.config_from or "",
|
||||
)
|
||||
|
||||
if args.get("load_balancing", {}).get("enabled"):
|
||||
if args.load_balancing.enabled:
|
||||
model_load_balancing_service.enable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
else:
|
||||
model_load_balancing_service.disable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@console_ns.expect(parser_delete_models)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -181,113 +238,53 @@ class ModelProviderModelApi(Resource):
|
|||
def delete(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_delete_models.parse_args()
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_get_credentials = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
parser_post_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser_put_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
parser_delete_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
class ModelProviderModelCredentialApi(Resource):
|
||||
@console_ns.expect(parser_get_credentials)
|
||||
@console_ns.expect(console_ns.models[ParserGetCredentials.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_get_credentials.parse_args()
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
current_credential = model_provider_service.get_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args.get("credential_id"),
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
config_from=args.get("config_from", ""),
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
config_from=args.config_from or "",
|
||||
)
|
||||
|
||||
if args.get("config_from", "") == "predefined-model":
|
||||
if args.config_from == "predefined-model":
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider
|
||||
)
|
||||
else:
|
||||
model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
|
||||
model_type = args.model_type
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
|
||||
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
|
||||
)
|
||||
|
||||
return jsonable_encoder(
|
||||
|
|
@ -304,7 +301,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
}
|
||||
)
|
||||
|
||||
@console_ns.expect(parser_post_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCreateCredential.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -312,7 +309,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_post_cred.parse_args()
|
||||
args = ParserCreateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -320,30 +317,30 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
model_provider_service.create_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
credentials=args.credentials,
|
||||
credential_name=args.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
logger.exception(
|
||||
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
|
||||
tenant_id,
|
||||
args.get("model"),
|
||||
args.get("model_type"),
|
||||
args.model,
|
||||
args.model_type,
|
||||
)
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@console_ns.expect(parser_put_cred)
|
||||
@console_ns.expect(console_ns.models[ParserUpdateCredential.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_put_cred.parse_args()
|
||||
args = ParserUpdateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -351,106 +348,87 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
model_provider_service.update_model_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credentials=args.credentials,
|
||||
credential_id=args.credential_id,
|
||||
credential_name=args.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(parser_delete_cred)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteCredential.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_delete_cred.parse_args()
|
||||
args = ParserDeleteCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_switch = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
class ParserSwitch(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credential_id: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@console_ns.expect(parser_switch)
|
||||
@console_ns.expect(console_ns.models[ParserSwitch.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_switch.parse_args()
|
||||
args = ParserSwitch.model_validate(console_ns.payload)
|
||||
|
||||
service = ModelProviderService()
|
||||
service.add_model_credential_to_model_list(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_model_enable_disable = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
|
||||
)
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
@console_ns.expect(parser_model_enable_disable)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_model_enable_disable.parse_args()
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.enable_model(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
@ -460,48 +438,43 @@ class ModelProviderModelEnableApi(Resource):
|
|||
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
|
||||
)
|
||||
class ModelProviderModelDisableApi(Resource):
|
||||
@console_ns.expect(parser_model_enable_disable)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_model_enable_disable.parse_args()
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.disable_model(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_validate = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
class ParserValidate(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
@console_ns.expect(parser_validate)
|
||||
@console_ns.expect(console_ns.models[ParserValidate.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_validate.parse_args()
|
||||
args = ParserValidate.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -512,9 +485,9 @@ class ModelProviderModelValidateApi(Resource):
|
|||
model_provider_service.validate_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
credentials=args.credentials,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
|
|
@ -528,24 +501,19 @@ class ModelProviderModelValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
parser_parameter = reqparse.RequestParser().add_argument(
|
||||
"model", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
@console_ns.expect(parser_parameter)
|
||||
@console_ns.expect(console_ns.models[ParserParameter.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
args = parser_parameter.parse_args()
|
||||
args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model
|
||||
)
|
||||
|
||||
return jsonable_encoder({"data": parameter_rules})
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import io
|
||||
from typing import Literal
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -17,6 +19,12 @@ from services.plugin.plugin_parameter_service import PluginParameterService
|
|||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/debugging-key")
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
|
|
@ -37,88 +45,194 @@ class PluginDebuggingKeyApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_list = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
)
|
||||
class ParserList(BaseModel):
|
||||
page: int = Field(default=1)
|
||||
page_size: int = Field(default=256)
|
||||
|
||||
|
||||
reg(ParserList)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list")
|
||||
class PluginListApi(Resource):
|
||||
@console_ns.expect(parser_list)
|
||||
@console_ns.expect(console_ns.models[ParserList.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = parser_list.parse_args()
|
||||
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
try:
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
|
||||
|
||||
|
||||
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
class ParserLatest(BaseModel):
|
||||
plugin_ids: list[str]
|
||||
|
||||
|
||||
class ParserIcon(BaseModel):
|
||||
tenant_id: str
|
||||
filename: str
|
||||
|
||||
|
||||
class ParserAsset(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
file_name: str
|
||||
|
||||
|
||||
class ParserGithubUpload(BaseModel):
|
||||
repo: str
|
||||
version: str
|
||||
package: str
|
||||
|
||||
|
||||
class ParserPluginIdentifiers(BaseModel):
|
||||
plugin_unique_identifiers: list[str]
|
||||
|
||||
|
||||
class ParserGithubInstall(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
repo: str
|
||||
version: str
|
||||
package: str
|
||||
|
||||
|
||||
class ParserPluginIdentifierQuery(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
|
||||
|
||||
class ParserTasks(BaseModel):
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class ParserMarketplaceUpgrade(BaseModel):
|
||||
original_plugin_unique_identifier: str
|
||||
new_plugin_unique_identifier: str
|
||||
|
||||
|
||||
class ParserGithubUpgrade(BaseModel):
|
||||
original_plugin_unique_identifier: str
|
||||
new_plugin_unique_identifier: str
|
||||
repo: str
|
||||
version: str
|
||||
package: str
|
||||
|
||||
|
||||
class ParserUninstall(BaseModel):
|
||||
plugin_installation_id: str
|
||||
|
||||
|
||||
class ParserPermissionChange(BaseModel):
|
||||
install_permission: TenantPluginPermission.InstallPermission
|
||||
debug_permission: TenantPluginPermission.DebugPermission
|
||||
|
||||
|
||||
class ParserDynamicOptions(BaseModel):
|
||||
plugin_id: str
|
||||
provider: str
|
||||
action: str
|
||||
parameter: str
|
||||
credential_id: str | None = None
|
||||
provider_type: Literal["tool", "trigger"]
|
||||
|
||||
|
||||
class PluginPermissionSettingsPayload(BaseModel):
|
||||
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
|
||||
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
||||
|
||||
class PluginAutoUpgradeSettingsPayload(BaseModel):
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = (
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
|
||||
)
|
||||
upgrade_time_of_day: int = 0
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
exclude_plugins: list[str] = Field(default_factory=list)
|
||||
include_plugins: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ParserPreferencesChange(BaseModel):
|
||||
permission: PluginPermissionSettingsPayload
|
||||
auto_upgrade: PluginAutoUpgradeSettingsPayload
|
||||
|
||||
|
||||
class ParserExcludePlugin(BaseModel):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
class ParserReadme(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
language: str = Field(default="en-US")
|
||||
|
||||
|
||||
reg(ParserLatest)
|
||||
reg(ParserIcon)
|
||||
reg(ParserAsset)
|
||||
reg(ParserGithubUpload)
|
||||
reg(ParserPluginIdentifiers)
|
||||
reg(ParserGithubInstall)
|
||||
reg(ParserPluginIdentifierQuery)
|
||||
reg(ParserTasks)
|
||||
reg(ParserMarketplaceUpgrade)
|
||||
reg(ParserGithubUpgrade)
|
||||
reg(ParserUninstall)
|
||||
reg(ParserPermissionChange)
|
||||
reg(ParserDynamicOptions)
|
||||
reg(ParserPreferencesChange)
|
||||
reg(ParserExcludePlugin)
|
||||
reg(ParserReadme)
|
||||
|
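Illustrative note (not part of this diff): the GET endpoints below feed request.args.to_dict(flat=True) into these models, so every value starts out as a string; the int fields rely on Pydantic's lax str-to-int coercion, and missing keys fall back to the declared defaults. A small sketch using ParserList as defined above:

# Sketch only: query-string values are strings; Pydantic coerces them for int fields.
print(ParserList.model_validate({}))                                # falls back to page=1, page_size=256
print(ParserList.model_validate({"page": "2", "page_size": "50"}))  # coerced to page=2, page_size=50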
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
|
||||
class PluginListLatestVersionsApi(Resource):
|
||||
@console_ns.expect(parser_latest)
|
||||
@console_ns.expect(console_ns.models[ParserLatest.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = parser_latest.parse_args()
|
||||
args = ParserLatest.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
versions = PluginService.list_latest_versions(args["plugin_ids"])
|
||||
versions = PluginService.list_latest_versions(args.plugin_ids)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"versions": versions})
|
||||
|
||||
|
||||
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@console_ns.expect(parser_ids)
|
||||
@console_ns.expect(console_ns.models[ParserLatest.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_ids.parse_args()
|
||||
args = ParserLatest.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
parser_icon = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
.add_argument("filename", type=str, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/icon")
|
||||
class PluginIconApi(Resource):
|
||||
@console_ns.expect(parser_icon)
|
||||
@console_ns.expect(console_ns.models[ParserIcon.__name__])
|
||||
@setup_required
|
||||
def get(self):
|
||||
args = parser_icon.parse_args()
|
||||
args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
|
||||
icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
|
@ -128,20 +242,16 @@ class PluginIconApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/asset")
|
||||
class PluginAssetApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserAsset.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
req = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
.add_argument("file_name", type=str, required=True, location="args")
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
try:
|
||||
binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
|
||||
binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
|
||||
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
|
@ -171,17 +281,9 @@ class PluginUploadFromPkgApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_github = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upload/github")
|
||||
class PluginUploadFromGithubApi(Resource):
|
||||
@console_ns.expect(parser_github)
|
||||
@console_ns.expect(console_ns.models[ParserGithubUpload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -189,10 +291,10 @@ class PluginUploadFromGithubApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_github.parse_args()
|
||||
args = ParserGithubUpload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
|
||||
response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
|
@ -223,47 +325,28 @@ class PluginUploadFromBundleApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_pkg = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/pkg")
|
||||
class PluginInstallFromPkgApi(Resource):
|
||||
@console_ns.expect(parser_pkg)
|
||||
@console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = parser_pkg.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
if not isinstance(plugin_unique_identifier, str):
|
||||
raise ValueError("Invalid plugin unique identifier")
|
||||
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||
response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_githubapi = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/github")
|
||||
class PluginInstallFromGithubApi(Resource):
|
||||
@console_ns.expect(parser_githubapi)
|
||||
@console_ns.expect(console_ns.models[ParserGithubInstall.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -271,15 +354,15 @@ class PluginInstallFromGithubApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_githubapi.parse_args()
|
||||
args = ParserGithubInstall.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_github(
|
||||
tenant_id,
|
||||
args["plugin_unique_identifier"],
|
||||
args["repo"],
|
||||
args["version"],
|
||||
args["package"],
|
||||
args.plugin_unique_identifier,
|
||||
args.repo,
|
||||
args.version,
|
||||
args.package,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
|
@ -287,14 +370,9 @@ class PluginInstallFromGithubApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_marketplace = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/marketplace")
|
||||
class PluginInstallFromMarketplaceApi(Resource):
|
||||
@console_ns.expect(parser_marketplace)
|
||||
@console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -302,43 +380,33 @@ class PluginInstallFromMarketplaceApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_marketplace.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
if not isinstance(plugin_unique_identifier, str):
|
||||
raise ValueError("Invalid plugin unique identifier")
|
||||
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||
response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_pkgapi = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
|
||||
class PluginFetchMarketplacePkgApi(Resource):
|
||||
@console_ns.expect(parser_pkgapi)
|
||||
@console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = parser_pkgapi.parse_args()
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"manifest": PluginService.fetch_marketplace_pkg(
|
||||
tenant_id,
|
||||
args["plugin_unique_identifier"],
|
||||
args.plugin_unique_identifier,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
|
@ -346,14 +414,9 @@ class PluginFetchMarketplacePkgApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_fetch = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
|
||||
class PluginFetchManifestApi(Resource):
|
||||
@console_ns.expect(parser_fetch)
|
||||
@console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -361,30 +424,19 @@ class PluginFetchManifestApi(Resource):
|
|||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_fetch.parse_args()
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"manifest": PluginService.fetch_plugin_manifest(
|
||||
tenant_id, args["plugin_unique_identifier"]
|
||||
).model_dump()
|
||||
}
|
||||
{"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()}
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_tasks = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks")
|
||||
class PluginFetchInstallTasksApi(Resource):
|
||||
@console_ns.expect(parser_tasks)
|
||||
@console_ns.expect(console_ns.models[ParserTasks.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -392,12 +444,10 @@ class PluginFetchInstallTasksApi(Resource):
|
|||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_tasks.parse_args()
|
||||
args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
|
||||
)
|
||||
return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
|
@ -462,16 +512,9 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_marketplace_api = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
|
||||
class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@console_ns.expect(parser_marketplace_api)
|
||||
@console_ns.expect(console_ns.models[ParserMarketplaceUpgrade.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -479,31 +522,21 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_marketplace_api.parse_args()
|
||||
args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
PluginService.upgrade_plugin_with_marketplace(
|
||||
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
|
||||
tenant_id, args.original_plugin_unique_identifier, args.new_plugin_unique_identifier
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_github_post = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/github")
|
||||
class PluginUpgradeFromGithubApi(Resource):
|
||||
@console_ns.expect(parser_github_post)
|
||||
@console_ns.expect(console_ns.models[ParserGithubUpgrade.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -511,56 +544,44 @@ class PluginUpgradeFromGithubApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_github_post.parse_args()
|
||||
args = ParserGithubUpgrade.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
PluginService.upgrade_plugin_with_github(
|
||||
tenant_id,
|
||||
args["original_plugin_unique_identifier"],
|
||||
args["new_plugin_unique_identifier"],
|
||||
args["repo"],
|
||||
args["version"],
|
||||
args["package"],
|
||||
args.original_plugin_unique_identifier,
|
||||
args.new_plugin_unique_identifier,
|
||||
args.repo,
|
||||
args.version,
|
||||
args.package,
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_uninstall = reqparse.RequestParser().add_argument(
|
||||
"plugin_installation_id", type=str, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/uninstall")
|
||||
class PluginUninstallApi(Resource):
|
||||
@console_ns.expect(parser_uninstall)
|
||||
@console_ns.expect(console_ns.models[ParserUninstall.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
args = parser_uninstall.parse_args()
|
||||
args = ParserUninstall.model_validate(console_ns.payload)
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
||||
return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_change_post = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("install_permission", type=str, required=True, location="json")
|
||||
.add_argument("debug_permission", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/change")
|
||||
class PluginChangePermissionApi(Resource):
|
||||
@console_ns.expect(parser_change_post)
|
||||
@console_ns.expect(console_ns.models[ParserPermissionChange.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -570,14 +591,15 @@ class PluginChangePermissionApi(Resource):
|
|||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_change_post.parse_args()
|
||||
|
||||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||
args = ParserPermissionChange.model_validate(console_ns.payload)
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||
return {
|
||||
"success": PluginPermissionService.change_permission(
|
||||
tenant_id, args.install_permission, args.debug_permission
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/fetch")
|
||||
|
|
@ -605,20 +627,9 @@ class PluginFetchPermissionApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_dynamic = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
.add_argument("provider", type=str, required=True, location="args")
|
||||
.add_argument("action", type=str, required=True, location="args")
|
||||
.add_argument("parameter", type=str, required=True, location="args")
|
||||
.add_argument("credential_id", type=str, required=False, location="args")
|
||||
.add_argument("provider_type", type=str, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
|
||||
class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@console_ns.expect(parser_dynamic)
|
||||
@console_ns.expect(console_ns.models[ParserDynamicOptions.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -627,18 +638,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||
current_user, tenant_id = current_account_with_tenant()
|
||||
user_id = current_user.id
|
||||
|
||||
args = parser_dynamic.parse_args()
|
||||
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=args["plugin_id"],
|
||||
provider=args["provider"],
|
||||
action=args["action"],
|
||||
parameter=args["parameter"],
|
||||
credential_id=args["credential_id"],
|
||||
provider_type=args["provider_type"],
|
||||
plugin_id=args.plugin_id,
|
||||
provider=args.provider,
|
||||
action=args.action,
|
||||
parameter=args.parameter,
|
||||
credential_id=args.credential_id,
|
||||
provider_type=args.provider_type,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
|
@ -646,16 +657,9 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
parser_change = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("permission", type=dict, required=True, location="json")
|
||||
.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||
class PluginChangePreferencesApi(Resource):
|
||||
@console_ns.expect(parser_change)
|
||||
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -664,22 +668,20 @@ class PluginChangePreferencesApi(Resource):
|
|||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_change.parse_args()
|
||||
args = ParserPreferencesChange.model_validate(console_ns.payload)
|
||||
|
||||
permission = args["permission"]
|
||||
permission = args.permission
|
||||
|
||||
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
|
||||
debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
|
||||
install_permission = permission.install_permission
|
||||
debug_permission = permission.debug_permission
|
||||
|
||||
auto_upgrade = args["auto_upgrade"]
|
||||
auto_upgrade = args.auto_upgrade
|
||||
|
||||
strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
|
||||
auto_upgrade.get("strategy_setting", "fix_only")
|
||||
)
|
||||
upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
|
||||
upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
|
||||
exclude_plugins = auto_upgrade.get("exclude_plugins", [])
|
||||
include_plugins = auto_upgrade.get("include_plugins", [])
|
||||
strategy_setting = auto_upgrade.strategy_setting
|
||||
upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
|
||||
upgrade_mode = auto_upgrade.upgrade_mode
|
||||
exclude_plugins = auto_upgrade.exclude_plugins
|
||||
include_plugins = auto_upgrade.include_plugins
|
||||
|
||||
# set permission
|
||||
set_permission_result = PluginPermissionService.change_permission(
|
||||
|
|
@ -744,12 +746,9 @@ class PluginFetchPreferencesApi(Resource):
|
|||
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
|
||||
|
||||
|
||||
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
|
||||
class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@console_ns.expect(parser_exclude)
|
||||
@console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -757,28 +756,20 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
|||
# exclude one single plugin
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_exclude.parse_args()
|
||||
args = ParserExcludePlugin.model_validate(console_ns.payload)
|
||||
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/readme")
|
||||
class PluginReadmeApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserReadme.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
.add_argument("language", type=str, required=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"readme": PluginService.fetch_plugin_readme(
|
||||
tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
|
||||
)
|
||||
}
|
||||
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
|
|
@ -23,9 +21,13 @@ from services.trigger.trigger_provider_service import TriggerProviderService
|
|||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||
class TriggerProviderIconApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -38,6 +40,7 @@ class TriggerProviderIconApi(Resource):
|
|||
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/triggers")
|
||||
class TriggerProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -50,6 +53,7 @@ class TriggerProviderListApi(Resource):
|
|||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -64,6 +68,7 @@ class TriggerProviderInfoApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
class TriggerSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -87,7 +92,16 @@ class TriggerSubscriptionListApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_type", type=str, required=False, nullable=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -97,9 +111,6 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
|||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_type", type=str, required=False, nullable=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -116,6 +127,9 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -127,7 +141,18 @@ class TriggerSubscriptionBuilderGetApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_api = (
|
||||
reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@console_ns.expect(parser_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -136,12 +161,8 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
|||
"""Verify a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
args = parser_api.parse_args()
|
||||
|
||||
try:
|
||||
# Use atomic update_and_verify to prevent race conditions
|
||||
|
|
@ -159,7 +180,24 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
|||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
parser_update_api = (
|
||||
reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@console_ns.expect(parser_update_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -169,18 +207,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_update_api.parse_args()
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
|
|
@ -200,6 +227,9 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -218,7 +248,11 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@console_ns.expect(parser_update_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -227,18 +261,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_update_api.parse_args()
|
||||
try:
|
||||
# Use atomic update_and_build to prevent race conditions
|
||||
TriggerSubscriptionBuilderService.update_and_build_builder(
|
||||
|
|
@ -258,6 +281,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
class TriggerSubscriptionDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -291,6 +317,7 @@ class TriggerSubscriptionDeleteApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize")
|
||||
class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -374,6 +401,7 @@ class TriggerOAuthAuthorizeApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
|
|
@ -438,6 +466,14 @@ class TriggerOAuthCallbackApi(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_oauth_client = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -484,6 +520,7 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
logger.exception("Error getting OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@console_ns.expect(parser_oauth_client)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -493,12 +530,7 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_oauth_client.parse_args()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
|
@ -536,52 +568,3 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
except Exception as e:
|
||||
logger.exception("Error removing OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
# Trigger Subscription
|
||||
console_ns.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||
console_ns.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
console_ns.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list"
|
||||
)
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionDeleteApi,
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
|
||||
# Trigger Subscription Builder
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
|
||||
|
||||
# OAuth
|
||||
console_ns.add_resource(
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
|
||||
)
|
||||
console_ns.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
console_ns.add_resource(
|
||||
TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
|
@ -32,8 +33,36 @@ from services.file_service import FileService
|
|||
from services.workspace_service import WorkspaceService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkspaceListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SwitchWorkspacePayload(BaseModel):
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class WorkspaceCustomConfigPayload(BaseModel):
|
||||
remove_webapp_brand: bool | None = None
|
||||
replace_webapp_logo: str | None = None
|
||||
|
||||
|
||||
class WorkspaceInfoPayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(WorkspaceListQuery)
|
||||
reg(SwitchWorkspacePayload)
|
||||
reg(WorkspaceCustomConfigPayload)
|
||||
reg(WorkspaceInfoPayload)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
"provider_type": fields.String,
|
||||
|
|
@ -95,18 +124,15 @@ class TenantListApi(Resource):
|
|||
|
||||
@console_ns.route("/all-workspaces")
|
||||
class WorkspaceListApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkspaceListQuery.__name__])
|
||||
@setup_required
|
||||
@admin_required
|
||||
def get(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = WorkspaceListQuery.model_validate(payload)
|
||||
|
||||
stmt = select(Tenant).order_by(Tenant.created_at.desc())
|
||||
tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
tenants = db.paginate(select=stmt, page=args.page, per_page=args.limit, error_out=False)
|
||||
has_more = False
|
||||
|
||||
if tenants.has_next:
|
||||
|
|
@ -115,8 +141,8 @@ class WorkspaceListApi(Resource):
|
|||
return {
|
||||
"data": marshal(tenants.items, workspace_fields),
|
||||
"has_more": has_more,
|
||||
"limit": args["limit"],
|
||||
"page": args["page"],
|
||||
"limit": args.limit,
|
||||
"page": args.page,
|
||||
"total": tenants.total,
|
||||
}, 200
|
||||
|
||||
|
|
@ -150,26 +176,24 @@ class TenantApi(Resource):
|
|||
return WorkspaceService.get_tenant_info(tenant), 200
|
||||
|
||||
|
||||
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
class SwitchWorkspaceApi(Resource):
|
||||
@console_ns.expect(parser_switch)
|
||||
@console_ns.expect(console_ns.models[SwitchWorkspacePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_switch.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = SwitchWorkspacePayload.model_validate(payload)
|
||||
|
||||
# check if tenant_id is valid, 403 if not
|
||||
try:
|
||||
TenantService.switch_tenant(current_user, args["tenant_id"])
|
||||
TenantService.switch_tenant(current_user, args.tenant_id)
|
||||
except Exception:
|
||||
raise AccountNotLinkTenantError("Account not link tenant")
|
||||
|
||||
new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
|
||||
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
|
||||
if new_tenant is None:
|
||||
raise ValueError("Tenant not found")
|
||||
|
||||
|
|
@ -178,24 +202,21 @@ class SwitchWorkspaceApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/custom-config")
|
||||
class CustomConfigWorkspaceApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkspaceCustomConfigPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("remove_webapp_brand", type=bool, location="json")
|
||||
.add_argument("replace_webapp_logo", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||
|
||||
custom_config_dict = {
|
||||
"remove_webapp_brand": args["remove_webapp_brand"],
|
||||
"replace_webapp_logo": args["replace_webapp_logo"]
|
||||
if args["replace_webapp_logo"] is not None
|
||||
"remove_webapp_brand": args.remove_webapp_brand,
|
||||
"replace_webapp_logo": args.replace_webapp_logo
|
||||
if args.replace_webapp_logo is not None
|
||||
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
||||
}
|
||||
|
||||
|
|
@ -245,24 +266,22 @@ class WebappLogoWorkspaceApi(Resource):
|
|||
return {"id": upload_file.id}, 201
|
||||
|
||||
|
||||
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/info")
|
||||
class WorkspaceInfoApi(Resource):
|
||||
@console_ns.expect(parser_info)
|
||||
@console_ns.expect(console_ns.models[WorkspaceInfoPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
# Change workspace name
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_info.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = WorkspaceInfoPayload.model_validate(payload)
|
||||
|
||||
if not current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||
tenant.name = args["name"]
|
||||
tenant.name = args.name
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from controllers.service_api.app.error import (
|
|||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
|
|
@ -30,6 +29,7 @@ from libs import helper
|
|||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
|
@ -88,7 +88,7 @@ class CompletionApi(Resource):
|
|||
This endpoint generates a completion based on the provided inputs and query.
|
||||
Supports both blocking and streaming response modes.
|
||||
"""
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
args = completion_parser.parse_args()
|
||||
|
|
@ -147,10 +147,15 @@ class CompletionStopApi(Resource):
|
|||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""Stop a running completion task."""
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
user_id=end_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -244,6 +249,11 @@ class ChatStopApi(Resource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
user_id=end_user.id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from flask import jsonify
|
||||
from flask import jsonify, request
|
||||
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.trigger import bp
|
||||
|
|
@ -28,8 +28,14 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
|
|||
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
return webhook_trigger, workflow, node_config, webhook_data, None
|
||||
except ValueError as e:
|
||||
# Fall back to raw extraction for error reporting
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
# Provide minimal context for error reporting without risking another parse failure
|
||||
webhook_data = {
|
||||
"method": request.method,
|
||||
"headers": dict(request.headers),
|
||||
"query_params": dict(request.args),
|
||||
"body": {},
|
||||
"files": {},
|
||||
}
|
||||
return webhook_trigger, workflow, node_config, webhook_data, str(e)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from controllers.web.error import (
|
|||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
|
|
@ -29,6 +28,7 @@ from libs import helper
|
|||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -64,7 +64,7 @@ class CompletionApi(WebApiResource):
|
|||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
|
|
@ -125,10 +125,15 @@ class CompletionStopApi(WebApiResource):
|
|||
}
|
||||
)
|
||||
def post(self, app_model, end_user, task_id):
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id=end_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -234,6 +239,11 @@ class ChatStopApi(WebApiResource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id=end_user.id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -62,7 +62,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
|||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
|
|
@ -72,7 +73,7 @@ from extensions.ext_database import db
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -580,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif event.stopped_by in (
|
||||
|
|
@ -590,7 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
# When hitting input-moderation or annotation-reply, the workflow will not start
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session)
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
|
|
@ -599,6 +600,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
event: QueueAdvancedChatMessageEndEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle advanced chat message end events."""
|
||||
|
|
@ -616,7 +618,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
|
|
@ -770,7 +772,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
def _save_message(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
):
|
||||
message = self._get_message(session=session)
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
|
|
@ -809,6 +817,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
metadata = self._task_state.metadata.model_dump()
|
||||
message.message_metadata = json.dumps(jsonable_encoder(metadata))
|
||||
|
||||
# Extract model provider and model_id from workflow node executions for tracing
|
||||
if message.workflow_run_id:
|
||||
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
|
||||
if model_info:
|
||||
message.model_provider = model_info.get("provider")
|
||||
message.model_id = model_info.get("model")
|
||||
|
||||
message_files = [
|
||||
MessageFile(
|
||||
message_id=message.id,
|
||||
|
|
@ -826,6 +842,68 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
# Trigger MESSAGE_TRACE for tracing integrations
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
|
||||
"""
|
||||
Extract model provider and model_id from workflow node executions.
|
||||
Returns dict with 'provider' and 'model' keys, or None if not found.
|
||||
"""
|
||||
try:
|
||||
# Query workflow node executions for LLM or Agent nodes
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
|
||||
.order_by(WorkflowNodeExecutionModel.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
node_execution = session.scalar(stmt)
|
||||
|
||||
if not node_execution:
|
||||
return None
|
||||
|
||||
# Try to extract from execution_metadata for agent nodes
|
||||
if node_execution.execution_metadata:
|
||||
try:
|
||||
metadata = json.loads(node_execution.execution_metadata)
|
||||
agent_log = metadata.get("agent_log", [])
|
||||
# Look for the first agent thought with provider info
|
||||
for log_entry in agent_log:
|
||||
entry_metadata = log_entry.get("metadata", {})
|
||||
provider_str = entry_metadata.get("provider")
|
||||
if provider_str:
|
||||
# Parse format like "langgenius/deepseek/deepseek"
|
||||
parts = provider_str.split("/")
|
||||
if len(parts) >= 3:
|
||||
return {"provider": parts[1], "model": parts[2]}
|
||||
elif len(parts) == 2:
|
||||
return {"provider": parts[0], "model": parts[1]}
|
||||
except (json.JSONDecodeError, KeyError, AttributeError) as e:
|
||||
logger.debug("Failed to parse execution_metadata: %s", e)
|
||||
|
||||
# Try to extract from process_data for llm nodes
|
||||
if node_execution.process_data:
|
||||
try:
|
||||
process_data = json.loads(node_execution.process_data)
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider and model:
|
||||
return {"provider": provider, "model": model}
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.debug("Failed to parse process_data: %s", e)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Failed to extract model info from workflow: %s", e)
|
||||
return None
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
|
|
|||
|
|
@ -155,8 +155,17 @@ class BaseAppGenerator:
|
|||
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
|
||||
)
|
||||
case VariableEntityType.CHECKBOX:
|
||||
if not isinstance(value, bool):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
|
||||
if isinstance(value, str):
|
||||
normalized_value = value.strip().lower()
|
||||
if normalized_value in {"true", "1", "yes", "on"}:
|
||||
value = True
|
||||
elif normalized_value in {"false", "0", "no", "off"}:
|
||||
value = False
|
||||
elif isinstance(value, (int, float)):
|
||||
if value == 1:
|
||||
value = True
|
||||
elif value == 0:
|
||||
value = False
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
|
|
|||
|
|
@ -258,6 +258,10 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
run_id = self._extract_workflow_run_id(runtime_state)
|
||||
self._workflow_execution_id = run_id
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
|
|
@ -414,9 +418,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
graph_runtime_state=validated_state,
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_partial_success_event(
|
||||
|
|
@ -437,10 +438,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
graph_runtime_state=validated_state,
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_failed_and_stop_events(
|
||||
|
|
@ -471,10 +468,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
error=error,
|
||||
exceptions_count=exceptions_count,
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_text_chunk_event(
|
||||
|
|
@ -655,7 +648,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
)
|
||||
|
||||
session.add(workflow_app_log)
|
||||
session.commit()
|
||||
|
||||
def _text_chunk_to_stream_response(
|
||||
self, text: str, from_variable_selector: list[str] | None = None
|
||||
|
|
|
|||
|
|
@ -40,6 +40,9 @@ class EasyUITaskState(TaskState):
|
|||
"""
|
||||
|
||||
llm_result: LLMResult
|
||||
first_token_time: float | None = None
|
||||
last_token_time: float | None = None
|
||||
is_streaming_response: bool = False
|
||||
|
||||
|
||||
class WorkflowTaskState(TaskState):
|
||||
|
|
|
|||
|
|
@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
|||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=self._state_owner_user_id,
|
||||
state=state.dumps(),
|
||||
pause_reasons=event.reasons,
|
||||
)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
|
|
|
|||
|
|
@ -332,6 +332,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
if not self._task_state.llm_result.prompt_messages:
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
|
||||
# Track streaming response times
|
||||
if self._task_state.first_token_time is None:
|
||||
self._task_state.first_token_time = time.perf_counter()
|
||||
self._task_state.is_streaming_response = True
|
||||
self._task_state.last_token_time = time.perf_counter()
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
|
||||
if should_direct_answer:
|
||||
|
|
@ -398,6 +404,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.llm_result.usage.latency = message.provider_response_latency
|
||||
|
||||
# Add streaming metrics to usage if available
|
||||
if self._task_state.is_streaming_response and self._task_state.first_token_time:
|
||||
start_time = self.start_at
|
||||
first_token_time = self._task_state.first_token_time
|
||||
last_token_time = self._task_state.last_token_time or first_token_time
|
||||
usage.time_to_first_token = round(first_token_time - start_time, 3)
|
||||
usage.time_to_generate = round(last_token_time - first_token_time, 3)
|
||||
|
||||
# Update metadata with the complete usage info
|
||||
self._task_state.metadata.usage = usage
|
||||
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
|
|
|
|||
|
|
@ -152,10 +152,5 @@ class CodeExecutor:
|
|||
raise CodeExecutionError(f"Unsupported language {language}")
|
||||
|
||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||
|
||||
try:
|
||||
response = cls.execute_code(language, preload, runner)
|
||||
except CodeExecutionError as e:
|
||||
raise e
|
||||
|
||||
response = cls.execute_code(language, preload, runner)
|
||||
return template_transformer.transform_response(response)
|
||||
|
|
|
|||
|
|
@ -1,308 +0,0 @@
|
|||
## Custom Integration of Pre-defined Models
|
||||
|
||||
### Introduction
|
||||
|
||||
After completing the vendor integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example and walk through a complete vendor integration.
|
||||
|
||||
It is important to note that for custom models, each model connection requires a complete vendor credential.
|
||||
|
||||
Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file.
|
||||
|
||||

|
||||
|
||||
As mentioned earlier, vendors do not need to implement `validate_provider_credentials`. The runtime will automatically call the corresponding model layer's `validate_credentials` to validate the credentials, based on the model type and model name selected by the user.
|
||||
|
||||
### Writing the Vendor YAML
|
||||
|
||||
First, we need to identify the types of models supported by the vendor we are integrating.
|
||||
|
||||
Currently supported model types are as follows:
|
||||
|
||||
- `llm` Text Generation Models
|
||||
|
||||
- `text_embedding` Text Embedding Models
|
||||
|
||||
- `rerank` Rerank Models
|
||||
|
||||
- `speech2text` Speech-to-Text
|
||||
|
||||
- `tts` Text-to-Speech
|
||||
|
||||
- `moderation` Moderation
|
||||
|
||||
Xinference supports LLM, Text Embedding, and Rerank, so we will start by writing `xinference.yaml`.
|
||||
|
||||
```yaml
|
||||
provider: xinference #Define the vendor identifier
|
||||
label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default.
|
||||
en_US: Xorbits Inference
|
||||
icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label
|
||||
en_US: icon_s_en.svg
|
||||
icon_large: # Large icon
|
||||
en_US: icon_l_en.svg
|
||||
help: # Help information
|
||||
title:
|
||||
en_US: How to deploy Xinference
|
||||
zh_Hans: 如何部署 Xinference
|
||||
url:
|
||||
en_US: https://github.com/xorbitsai/inference
|
||||
supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models.
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
```
|
||||
|
||||
Then, we need to determine what credentials are required to define a model in Xinference.
|
||||
|
||||
- Since it supports three different types of models, we need a `model_type` field to indicate which type a given model is. Here is how we can define it:
|
||||
|
||||
```yaml
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: model_type
|
||||
type: select
|
||||
label:
|
||||
en_US: Model type
|
||||
zh_Hans: 模型类型
|
||||
required: true
|
||||
options:
|
||||
- value: text-generation
|
||||
label:
|
||||
en_US: Language Model
|
||||
zh_Hans: 语言模型
|
||||
- value: embeddings
|
||||
label:
|
||||
en_US: Text Embedding
|
||||
- value: reranking
|
||||
label:
|
||||
en_US: Rerank
|
||||
```
|
||||
|
||||
- Next, each model has its own model_name, so we need to define that here:
|
||||
|
||||
```yaml
|
||||
- variable: model_name
|
||||
type: text-input
|
||||
label:
|
||||
en_US: Model name
|
||||
zh_Hans: 模型名称
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 填写模型名称
|
||||
en_US: Input model name
|
||||
```
|
||||
|
||||
- Specify the Xinference local deployment address:
|
||||
|
||||
```yaml
|
||||
- variable: server_url
|
||||
label:
|
||||
zh_Hans: 服务器 URL
|
||||
en_US: Server url
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
|
||||
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
||||
```
|
||||
|
||||
- Each model has a unique model_uid, so we also need to define that here:
|
||||
|
||||
```yaml
|
||||
- variable: model_uid
|
||||
label:
|
||||
zh_Hans: 模型 UID
|
||||
en_US: Model uid
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 Model UID
|
||||
en_US: Enter the model uid
|
||||
```
|
||||
|
||||
Now, we have completed the basic definition of the vendor.
|
||||
|
||||
### Writing the Model Code
|
||||
|
||||
Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`.
|
||||
|
||||
In `llm.py`, create a Xinference LLM class. We name it `XinferenceAILargeLanguageModel` (the name can be arbitrary); it inherits from the `__base.large_language_model.LargeLanguageModel` base class and implements the following methods:
|
||||
|
||||
- LLM Invocation
|
||||
|
||||
Implement the core method for LLM invocation, supporting both stream and synchronous responses.
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool usage
|
||||
:param stop: stop words
|
||||
:param stream: is the response a stream
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
When implementing, make sure to use two separate functions to return data for synchronous and stream responses. This is important because Python treats any function containing the `yield` keyword as a generator function, which must return a `Generator` type. Here's an example (note that the example uses simplified parameters; in a real implementation, use the parameter list as defined above):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- Pre-compute Input Tokens
|
||||
|
||||
If the model does not provide an interface for pre-computing tokens, you can return 0 directly.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool usage
|
||||
:return: token count
|
||||
"""
|
||||
```
|
||||
|
||||
Sometimes you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get an estimated token count, making sure the environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`. This method is provided by the `AIModel` base class and uses the GPT-2 tokenizer for the calculation. Note, however, that it is only an approximation and may not be fully accurate.
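For example, a minimal sketch of such a fallback might look like the following; joining the message contents with spaces is a simplification, and the exact handling of `PromptMessage` content is an assumption:

```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    # Xinference exposes no token-counting endpoint, so approximate with the
    # GPT-2 tokenizer helper inherited from the AIModel base class.
    text = " ".join(str(message.content) for message in prompt_messages)
    return self._get_num_tokens_by_gpt2(text)
```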
|
||||
|
||||
- Model Credentials Validation
|
||||
|
||||
Similar to vendor credentials validation, this method validates individual model credentials.
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: None
|
||||
"""
|
||||
```
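As a rough sketch, a Xinference credential check could simply attempt a minimal invocation and convert any failure into `CredentialsValidateFailedError`; the import path and the `UserPromptMessage` entity used below are assumptions about the runtime:

```python
from core.model_runtime.errors.validate import CredentialsValidateFailedError


def validate_credentials(self, model: str, credentials: dict) -> None:
    try:
        # A minimal one-token call is usually enough to confirm server_url and model_uid.
        self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=[UserPromptMessage(content="ping")],
            model_parameters={"max_tokens": 1},
            stream=False,
        )
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))
```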
|
||||
|
||||
- Model Parameter Schema
|
||||
|
||||
Unlike pre-defined models, a custom model's YAML file does not define which parameters it supports, so we need to generate the model parameter schema dynamically.
|
||||
|
||||
For instance, Xinference supports `max_tokens`, `temperature`, and `top_p` parameters.
|
||||
|
||||
However, some vendors may support different parameters for different models. For example, the `OpenLLM` vendor supports `top_k`, but not all models provided by this vendor support `top_k`. Let's say model A supports `top_k` but model B does not. In such cases, we need to dynamically generate the model parameter schema, as illustrated below:
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature', type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度', en_US='Temperature'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p', type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P', en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens', type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度', en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# if model is A, add top_k to rules
|
||||
if model == 'A':
|
||||
rules.append(
|
||||
ParameterRule(
|
||||
name='top_k', type=ParameterType.INT,
|
||||
use_template='top_k',
|
||||
min=1,
|
||||
default=50,
|
||||
label=I18nObject(
|
||||
zh_Hans='Top K', en_US='Top K'
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
"""
|
||||
some NOT IMPORTANT code here
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=model_type,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: ModelType.LLM,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
return entity
|
||||
```
|
||||
|
||||
- Exception Error Mapping
|
||||
|
||||
When a model invocation error occurs, it should be mapped to the runtime's specified `InvokeError` type, enabling Dify to handle different errors appropriately.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Connection error during invocation
|
||||
- `InvokeServerUnavailableError` Service provider unavailable
|
||||
- `InvokeRateLimitError` Rate limit reached
|
||||
- `InvokeAuthorizationError` Authorization failure
|
||||
- `InvokeBadRequestError` Invalid request parameters
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
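As an illustration only, a mapping might group the exceptions raised by the underlying HTTP client under these runtime types. The `requests` exceptions below are an assumption about the client in use; substitute whatever your SDK actually raises.

```python
# Illustrative mapping: replace the right-hand exception lists with the ones your
# HTTP client or vendor SDK actually raises.
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    import requests

    return {
        InvokeConnectionError: [requests.exceptions.ConnectionError, requests.exceptions.Timeout],
        InvokeServerUnavailableError: [requests.exceptions.HTTPError],
        InvokeRateLimitError: [],
        InvokeAuthorizationError: [],
        InvokeBadRequestError: [ValueError],
    }
```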
|
||||
|
||||
For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
|
||||
|
|
|
@ -1,701 +0,0 @@
|
|||
# Interface Methods
|
||||
|
||||
This section describes the interface methods and parameter explanations that need to be implemented by providers and various model types.
|
||||
|
||||
## Provider
|
||||
|
||||
Inherit the `__base.model_provider.ModelProvider` base class and implement the following interfaces:
|
||||
|
||||
```python
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
```
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by the `provider_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
If verification fails, throw the `errors.validate.CredentialsValidateFailedError` error.
|
||||
|
||||
## Model
|
||||
|
||||
Models are divided into 5 different types, each inheriting from different base classes and requiring the implementation of different methods.
|
||||
|
||||
All models need to uniformly implement the following 2 methods:
|
||||
|
||||
- Model Credential Verification
|
||||
|
||||
Similar to provider credential verification, this step involves verification for an individual model.
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
If verification fails, throw the `errors.validate.CredentialsValidateFailedError` error.
|
||||
|
||||
- Invocation Error Mapping Table
|
||||
|
||||
When there is an exception in model invocation, it needs to be mapped to the `InvokeError` type specified by Runtime. This facilitates Dify's ability to handle different errors with appropriate follow-up actions.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Invocation connection error
|
||||
- `InvokeServerUnavailableError` Invocation service provider unavailable
|
||||
- `InvokeRateLimitError` Invocation reached rate limit
|
||||
- `InvokeAuthorizationError` Invocation authorization failure
|
||||
- `InvokeBadRequestError` Invocation parameter error
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
You can refer to OpenAI's `_invoke_error_mapping` for an example.
|
||||
|
||||
### LLM
|
||||
|
||||
Inherit the `__base.large_language_model.LargeLanguageModel` base class and implement the following interfaces:
|
||||
|
||||
- LLM Invocation
|
||||
|
||||
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) List of prompts
|
||||
|
||||
If the model is of the `Completion` type, the list only needs to include one [UserPromptMessage](#UserPromptMessage) element;
|
||||
|
||||
If the model is of the `Chat` type, it requires a list of elements such as [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) depending on the message.
|
||||
|
||||
- `model_parameters` (object) Model parameters
|
||||
|
||||
The model parameters are defined by the `parameter_rules` in the model's YAML configuration.
|
||||
|
||||
- `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] List of tools, equivalent to the `function` in `function calling`.
|
||||
|
||||
That is, the tool list for tool calling.
|
||||
|
||||
- `stop` (array[string]) [optional] Stop sequences
|
||||
|
||||
The model output will stop before the string defined by the stop sequence.
|
||||
|
||||
- `stream` (bool) Whether to output in a streaming manner, default is True
|
||||
|
||||
Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult).
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns
|
||||
|
||||
Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult).
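As a brief caller-side illustration of these parameters and return types, where `llm` is an instance of the model class and assuming the runtime exposes a public `invoke` wrapper with the same signature as `_invoke` (the model name and credential value are placeholders):

```python
# Sketch only: "my-model" and the credential value are placeholders.
messages = [
    SystemPromptMessage(content="You are a helpful assistant."),
    UserPromptMessage(content="Summarize this document in one sentence."),
]

response = llm.invoke(
    model="my-model",
    credentials={"api_key": "..."},
    prompt_messages=messages,
    model_parameters={"temperature": 0.7, "max_tokens": 256},
    stream=True,
)

# stream=True yields LLMResultChunk items; stream=False would return a single LLMResult.
for chunk in response:
    print(chunk.delta.message.content, end="")
```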
|
||||
|
||||
- Pre-calculating Input Tokens
|
||||
|
||||
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
For parameter explanations, refer to the above section on `LLM Invocation`.
|
||||
|
||||
- Fetch Custom Model Schema [Optional]
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
```
|
||||
|
||||
When the provider supports adding custom LLMs, this method can be implemented so that custom models can fetch their model schema. The default implementation returns None.
|
||||
|
||||
### TextEmbedding
|
||||
|
||||
Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and implement the following interfaces:
|
||||
|
||||
- Embedding Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `texts` (array[string]) List of texts, capable of batch processing
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
[TextEmbeddingResult](#TextEmbeddingResult) entity.
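A sketch of what an implementation might return; the vectors and usage figures are placeholders (in this sketch `total_price` is simply `tokens × unit_price × price_unit`).

```python
# Sketch only: embedding vectors and usage figures are placeholders; a real
# implementation would take them from the provider's embedding API response.
return TextEmbeddingResult(
    model=model,
    embeddings=[
        [0.01, -0.02, 0.03],   # vector for texts[0]
        [0.04, 0.00, -0.05],   # vector for texts[1]
    ],
    usage=EmbeddingUsage(
        tokens=12,
        total_tokens=12,
        unit_price=Decimal("0.1"),      # price per pricing unit
        price_unit=Decimal("0.001"),    # pricing unit (per 1K tokens here)
        total_price=Decimal("0.0012"),  # 12 * 0.1 * 0.001
        currency="USD",
        latency=0.21,
    ),
)
```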
|
||||
|
||||
- Pre-calculating Tokens
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
For parameter explanations, refer to the above section on `Embedding Invocation`.
|
||||
|
||||
### Rerank
|
||||
|
||||
Inherit the `__base.rerank_model.RerankModel` base class and implement the following interfaces:
|
||||
|
||||
- Rerank Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `query` (string) Query request content
|
||||
|
||||
- `docs` (array[string]) List of segments to be reranked
|
||||
|
||||
- `score_threshold` (float) [optional] Score threshold
|
||||
|
||||
- `top_n` (int) [optional] Select the top n segments
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
[RerankResult](#RerankResult) entity.
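A sketch of the returned shape, with placeholder scores, after `score_threshold` and `top_n` have been applied:

```python
# Sketch only: scores are placeholders. `index` keeps the position of each document
# in the original `docs` list, which is why it is preserved after filtering/sorting.
return RerankResult(
    model=model,
    docs=[
        RerankDocument(index=2, text=docs[2], score=0.92),
        RerankDocument(index=0, text=docs[0], score=0.87),
    ],
)
```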
|
||||
|
||||
### Speech2text
|
||||
|
||||
Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement the following interfaces:
|
||||
|
||||
- Speech-to-Text Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke speech-to-text model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `file` (File) File stream
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
The string after speech-to-text conversion.
|
||||
|
||||
### Text2speech
|
||||
|
||||
Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement the following interfaces:
|
||||
|
||||
- Text-to-Speech Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
|
||||
"""
|
||||
Invoke text-to-speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be converted
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: converted audio
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `content_text` (string) The text content that needs to be converted
|
||||
|
||||
- `streaming` (bool) Whether to stream output
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
The converted speech audio stream.
|
||||
|
||||
### Moderation
|
||||
|
||||
Inherit the `__base.moderation_model.ModerationModel` base class and implement the following interfaces:
|
||||
|
||||
- Moderation Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `text` (string) Text content
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
False indicates that the input text is safe, True indicates otherwise.
|
||||
|
||||
## Entities
|
||||
|
||||
### PromptMessageRole
|
||||
|
||||
Message role
|
||||
|
||||
```python
|
||||
class PromptMessageRole(Enum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
```
|
||||
|
||||
### PromptMessageContentType
|
||||
|
||||
Message content types, divided into text and image.
|
||||
|
||||
```python
|
||||
class PromptMessageContentType(Enum):
|
||||
"""
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
TEXT = 'text'
|
||||
IMAGE = 'image'
|
||||
```
|
||||
|
||||
### PromptMessageContent
|
||||
|
||||
Message content base class, used only for parameter declaration and cannot be initialized.
|
||||
|
||||
```python
|
||||
class PromptMessageContent(BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType
|
||||
data: str
|
||||
```
|
||||
|
||||
Currently, two types are supported: text and image. It's possible to simultaneously input text and multiple images.
|
||||
|
||||
You need to initialize `TextPromptMessageContent` and `ImagePromptMessageContent` separately for input.
|
||||
|
||||
### TextPromptMessageContent
|
||||
|
||||
```python
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
```
|
||||
|
||||
If inputting a combination of text and images, the text needs to be constructed into this entity as part of the `content` list.
|
||||
|
||||
### ImagePromptMessageContent
|
||||
|
||||
```python
|
||||
class ImagePromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for image prompt message content.
|
||||
"""
|
||||
class DETAIL(Enum):
|
||||
LOW = 'low'
|
||||
HIGH = 'high'
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW # Resolution
|
||||
```
|
||||
|
||||
If inputting a combination of text and images, the images need to be constructed into this entity as part of the `content` list.
|
||||
|
||||
`data` can be either a `url` or a `base64` encoded string of the image.
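For example, a user turn combining a question with one image could be built like this (the URL is a placeholder; a base64-encoded string would work the same way):

```python
# The image URL is a placeholder; `data` could equally be a base64-encoded string.
message = UserPromptMessage(content=[
    TextPromptMessageContent(data="Describe the chart in this image."),
    ImagePromptMessageContent(
        data="https://example.com/chart.png",
        detail=ImagePromptMessageContent.DETAIL.HIGH,
    ),
])
```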
|
||||
|
||||
### PromptMessage
|
||||
|
||||
The base class for all Role message bodies, used only for parameter declaration and cannot be initialized.
|
||||
|
||||
```python
|
||||
class PromptMessage(BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | list[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
|
||||
name: Optional[str] = None
|
||||
```
|
||||
|
||||
### UserPromptMessage
|
||||
|
||||
UserMessage message body, representing a user's message.
|
||||
|
||||
```python
|
||||
class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for user prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.USER
|
||||
```
|
||||
|
||||
### AssistantPromptMessage
|
||||
|
||||
Represents a message returned by the model, typically used for `few-shots` or inputting chat history.
|
||||
|
||||
```python
|
||||
class AssistantPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for assistant prompt message.
|
||||
"""
|
||||
class ToolCall(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call.
|
||||
"""
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call function.
|
||||
"""
|
||||
name: str # tool name
|
||||
arguments: str # tool arguments
|
||||
|
||||
id: str # Tool ID, effective only in OpenAI tool calls. It's the unique ID for tool invocation and the same tool can be called multiple times.
|
||||
type: str # default: function
|
||||
function: ToolCallFunction # tool call information
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
|
||||
```
|
||||
|
||||
Where `tool_calls` are the list of `tool calls` returned by the model after invoking the model with the `tools` input.
|
||||
|
||||
### SystemPromptMessage
|
||||
|
||||
Represents system messages, usually used for setting system commands given to the model.
|
||||
|
||||
```python
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for system prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.SYSTEM
|
||||
```
|
||||
|
||||
### ToolPromptMessage
|
||||
|
||||
Represents tool messages, used for conveying the results of a tool execution to the model for the next step of processing.
|
||||
|
||||
```python
|
||||
class ToolPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for tool prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.TOOL
|
||||
tool_call_id: str # Tool invocation ID. If OpenAI tool call is not supported, the name of the tool can also be inputted.
|
||||
```
|
||||
|
||||
The base class's `content` takes in the results of tool execution.
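Putting the two message types together, a typical tool-call round trip appends the assistant's tool call and then a `ToolPromptMessage` carrying the tool's output before the next invocation (the tool name, arguments, and result below are illustrative):

```python
# Illustrative round trip: tool name, arguments, and result are placeholders.
assistant_message = AssistantPromptMessage(
    content="",
    tool_calls=[
        AssistantPromptMessage.ToolCall(
            id="call_0",
            type="function",
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name="get_weather",
                arguments='{"city": "Berlin"}',
            ),
        )
    ],
)

tool_message = ToolPromptMessage(
    tool_call_id="call_0",            # matches the assistant's tool call id
    content='{"temperature_c": 21}',  # the tool's execution result
)

# Both are appended to prompt_messages before the next model invocation.
prompt_messages += [assistant_message, tool_message]
```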
|
||||
|
||||
### PromptMessageTool
|
||||
|
||||
```python
|
||||
class PromptMessageTool(BaseModel):
|
||||
"""
|
||||
Model class for prompt message tool.
|
||||
"""
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
||||
```
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### LLMResult
|
||||
|
||||
```python
|
||||
class LLMResult(BaseModel):
|
||||
"""
|
||||
Model class for llm result.
|
||||
"""
|
||||
model: str # Actual model used
|
||||
prompt_messages: list[PromptMessage] # prompt messages
|
||||
message: AssistantPromptMessage # response message
|
||||
usage: LLMUsage # usage info
|
||||
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
|
||||
```
|
||||
|
||||
### LLMResultChunkDelta
|
||||
|
||||
In streaming returns, each iteration contains the `delta` entity.
|
||||
|
||||
```python
|
||||
class LLMResultChunkDelta(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk delta.
|
||||
"""
|
||||
index: int
|
||||
message: AssistantPromptMessage # response message
|
||||
usage: Optional[LLMUsage] = None # usage info
|
||||
finish_reason: Optional[str] = None # finish reason, only the last one returns
|
||||
```
|
||||
|
||||
### LLMResultChunk
|
||||
|
||||
Each iteration entity in streaming returns.
|
||||
|
||||
```python
|
||||
class LLMResultChunk(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk.
|
||||
"""
|
||||
model: str # Actual model used
|
||||
prompt_messages: list[PromptMessage] # prompt messages
|
||||
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
|
||||
delta: LLMResultChunkDelta
|
||||
```
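In a streaming handler, each raw provider chunk is wrapped into one `LLMResultChunk` and yielded. The sketch below assumes the provider returns an iterable of dict-like chunks; how `content` and `finish_reason` are extracted is provider-specific.

```python
# Sketch only: `response` stands for the provider's streaming response; adapt the
# field extraction to the actual chunk format of your provider.
def _handle_stream_response(self, model: str, prompt_messages: list[PromptMessage],
                            response) -> Generator:
    for index, raw_chunk in enumerate(response):
        yield LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=index,
                message=AssistantPromptMessage(content=raw_chunk.get("content", "")),
                finish_reason=raw_chunk.get("finish_reason"),  # None until the last chunk
            ),
        )
```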
|
||||
|
||||
### LLMUsage
|
||||
|
||||
```python
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
Model class for LLM usage.
|
||||
"""
|
||||
prompt_tokens: int # Tokens used for prompt
|
||||
prompt_unit_price: Decimal # Unit price for prompt
|
||||
prompt_price_unit: Decimal # Price unit for prompt, i.e., the unit price based on how many tokens
|
||||
prompt_price: Decimal # Cost for prompt
|
||||
completion_tokens: int # Tokens used for response
|
||||
completion_unit_price: Decimal # Unit price for response
|
||||
completion_price_unit: Decimal # Price unit for response, i.e., the unit price based on how many tokens
|
||||
completion_price: Decimal # Cost for response
|
||||
total_tokens: int # Total number of tokens used
|
||||
total_price: Decimal # Total cost
|
||||
currency: str # Currency unit
|
||||
latency: float # Request latency (s)
|
||||
```
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### TextEmbeddingResult
|
||||
|
||||
```python
|
||||
class TextEmbeddingResult(BaseModel):
|
||||
"""
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
model: str # Actual model used
|
||||
embeddings: list[list[float]] # List of embedding vectors, corresponding to the input texts list
|
||||
usage: EmbeddingUsage # Usage information
|
||||
```
|
||||
|
||||
### EmbeddingUsage
|
||||
|
||||
```python
|
||||
class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
Model class for embedding usage.
|
||||
"""
|
||||
tokens: int # Number of tokens used
|
||||
total_tokens: int # Total number of tokens used
|
||||
unit_price: Decimal # Unit price
|
||||
price_unit: Decimal # Price unit, i.e., the unit price based on how many tokens
|
||||
total_price: Decimal # Total cost
|
||||
currency: str # Currency unit
|
||||
latency: float # Request latency (s)
|
||||
```
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### RerankResult
|
||||
|
||||
```python
|
||||
class RerankResult(BaseModel):
|
||||
"""
|
||||
Model class for rerank result.
|
||||
"""
|
||||
model: str # Actual model used
|
||||
docs: list[RerankDocument] # Reranked document list
|
||||
```
|
||||
|
||||
### RerankDocument
|
||||
|
||||
```python
|
||||
class RerankDocument(BaseModel):
|
||||
"""
|
||||
Model class for rerank document.
|
||||
"""
|
||||
index: int # original index
|
||||
text: str
|
||||
score: float
|
||||
```
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
## Predefined Model Integration
|
||||
|
||||
After completing the vendor integration, the next step is to integrate the models from the vendor.
|
||||
|
||||
First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory.
|
||||
|
||||
Currently supported model types are:
|
||||
|
||||
- `llm` Text Generation Model
|
||||
- `text_embedding` Text Embedding Model
|
||||
- `rerank` Rerank Model
|
||||
- `speech2text` Speech-to-Text
|
||||
- `tts` Text-to-Speech
|
||||
- `moderation` Moderation
|
||||
|
||||
Continuing with `Anthropic` as an example, `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`.
|
||||
|
||||
For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
|
||||
|
||||
### Prepare Model YAML
|
||||
|
||||
```yaml
|
||||
model: claude-2.1 # Model identifier
|
||||
# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US.
|
||||
# This can also be omitted, in which case the model identifier will be used as the label
|
||||
label:
|
||||
en_US: claude-2.1
|
||||
model_type: llm # Model type, claude-2.1 is an LLM
|
||||
features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding
|
||||
- agent-thought
|
||||
model_properties: # Model properties
|
||||
mode: chat # LLM mode, complete for text completion models, chat for conversation models
|
||||
context_size: 200000 # Maximum context size
|
||||
parameter_rules: # Parameter rules for the model call; only LLM requires this
|
||||
- name: temperature # Parameter variable name
|
||||
# Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
|
||||
# The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||
# Additional configuration parameters will override the default configuration if set
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label: # Display name of the parameter
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int # Parameter type, supports float/int/string/boolean
|
||||
help: # Help information, describing the parameter's function
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false # Whether the parameter is mandatory; can be omitted
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
default: 4096 # Default value of the parameter
|
||||
min: 1 # Minimum value of the parameter, applicable to float/int only
|
||||
max: 4096 # Maximum value of the parameter, applicable to float/int only
|
||||
pricing: # Pricing information
|
||||
input: '8.00' # Input unit price, i.e., prompt price
|
||||
output: '24.00' # Output unit price, i.e., response content price
|
||||
unit: '0.000001' # Price unit, meaning the above prices are quoted per 1M tokens
|
||||
currency: USD # Price currency
|
||||
```
|
||||
|
||||
It is recommended to prepare all model configurations before starting the implementation of the model code.
|
||||
|
||||
You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).
|
||||
|
||||
### Implement the Model Call Code
|
||||
|
||||
Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
|
||||
|
||||
Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
|
||||
|
||||
- LLM Call
|
||||
|
||||
Implement the core method for calling the LLM, supporting both streaming and synchronous responses.
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
Use two separate functions to return data, one for synchronous responses and one for streaming responses, because Python treats any function containing the `yield` keyword as a generator function, which fixes its return type to `Generator`. Synchronous and streaming returns therefore need to be implemented separately, as shown below (note that the example uses simplified parameters; the actual implementation should follow the parameter list above):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- Pre-compute Input Tokens
|
||||
|
||||
If the model does not provide an interface to precompute tokens, return 0 directly.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Validate Model Credentials
|
||||
|
||||
Similar to vendor credential validation, but specific to a single model.
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Map Invoke Errors
|
||||
|
||||
When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Connection error
|
||||
|
||||
- `InvokeServerUnavailableError` Service provider unavailable
|
||||
|
||||
- `InvokeRateLimitError` Rate limit reached
|
||||
|
||||
- `InvokeAuthorizationError` Authorization failed
|
||||
|
||||
- `InvokeBadRequestError` Parameter error
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
|
||||
|
|
@ -1,266 +0,0 @@
|
|||
## Adding a New Provider
|
||||
|
||||
Providers support three types of model configuration methods:
|
||||
|
||||
- `predefined-model` Predefined model
|
||||
|
||||
This indicates that users only need to configure the unified provider credentials to use the predefined models under the provider.
|
||||
|
||||
- `customizable-model` Customizable model
|
||||
|
||||
Users need to add credential configurations for each model.
|
||||
|
||||
- `fetch-from-remote` Fetch from remote
|
||||
|
||||
This is consistent with the `predefined-model` configuration method. Only unified provider credentials need to be configured, and models are obtained from the provider through credential information.
|
||||
|
||||
These three configuration methods **can coexist**, meaning a provider can support `predefined-model` + `customizable-model` or `predefined-model` + `fetch-from-remote`, etc. In other words, configuring the unified provider credentials allows the use of predefined and remotely fetched models, and if new models are added, they can be used in addition to the custom models.
|
||||
|
||||
## Getting Started
|
||||
|
||||
Adding a new provider starts with determining the English identifier of the provider, such as `anthropic`, and using this identifier to create a `module` in `model_providers`.
|
||||
|
||||
Under this `module`, we first need to prepare the provider's YAML configuration.
|
||||
|
||||
### Preparing Provider YAML
|
||||
|
||||
Here, using `Anthropic` as an example, we preset the provider's basic information, supported model types, configuration methods, and credential rules.
|
||||
|
||||
```YAML
|
||||
provider: anthropic # Provider identifier
|
||||
label: # Provider display name, can be set in en_US English and zh_Hans Chinese, zh_Hans will default to en_US if not set.
|
||||
en_US: Anthropic
|
||||
icon_small: # Small provider icon, stored in the _assets directory under the corresponding provider implementation directory, same language strategy as label
|
||||
en_US: icon_s_en.png
|
||||
icon_large: # Large provider icon, stored in the _assets directory under the corresponding provider implementation directory, same language strategy as label
|
||||
en_US: icon_l_en.png
|
||||
supported_model_types: # Supported model types, Anthropic only supports LLM
|
||||
- llm
|
||||
configurate_methods: # Supported configuration methods, Anthropic only supports predefined models
|
||||
- predefined-model
|
||||
provider_credential_schema: # Provider credential rules, as Anthropic only supports predefined models, unified provider credential rules need to be defined
|
||||
credential_form_schemas: # List of credential form items
|
||||
- variable: anthropic_api_key # Credential parameter variable name
|
||||
label: # Display name
|
||||
en_US: API Key
|
||||
type: secret-input # Form type, here secret-input represents an encrypted information input box, showing masked information when editing.
|
||||
required: true # Whether required
|
||||
placeholder: # Placeholder information
|
||||
zh_Hans: Enter your API Key here
|
||||
en_US: Enter your API Key
|
||||
- variable: anthropic_api_url
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input # Form type, here text-input represents a text input box
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: Enter your API URL here
|
||||
en_US: Enter your API URL
|
||||
```
|
||||
|
||||
You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider).
|
||||
|
||||
### Implementing Provider Code
|
||||
|
||||
Providers need to inherit the `__base.model_provider.ModelProvider` base class and implement the `validate_provider_credentials` method for unified provider credential verification. For reference, see [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py).
|
||||
|
||||
> If the provider is of the `customizable-model` type, there is no need to implement the `validate_provider_credentials` method.
|
||||
|
||||
```python
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
```
|
||||
|
||||
You can also leave `validate_provider_credentials` as a stub for now and reuse the model credential verification method once it has been implemented, as sketched below.
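This is a sketch that assumes the base class offers a `get_model_instance` helper (as in the reference implementation linked above); the model name is a placeholder for one the provider is known to serve.

```python
# Sketch: reuse the LLM's validate_credentials for provider-level validation.
# "claude-3-haiku-20240307" is a placeholder model name.
def validate_provider_credentials(self, credentials: dict) -> None:
    try:
        model_instance = self.get_model_instance(ModelType.LLM)
        model_instance.validate_credentials(model="claude-3-haiku-20240307", credentials=credentials)
    except CredentialsValidateFailedError:
        raise
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))
```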
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### Adding Models
|
||||
|
||||
After the provider integration is complete, the next step is to integrate models under the provider.
|
||||
|
||||
First, we need to determine the type of the model to be integrated and create a `module` for the corresponding model type in the provider's directory.
|
||||
|
||||
The currently supported model types are as follows:
|
||||
|
||||
- `llm` Text generation model
|
||||
- `text_embedding` Text Embedding model
|
||||
- `rerank` Rerank model
|
||||
- `speech2text` Speech to text
|
||||
- `tts` Text to speech
|
||||
- `moderation` Moderation
|
||||
|
||||
Continuing with `Anthropic` as an example, since `Anthropic` only supports LLM, we create a `module` named `llm` in `model_providers.anthropic`.
|
||||
|
||||
For predefined models, we first need to create a YAML file named after the model, such as `claude-2.1.yaml`, under the `llm` `module`.
|
||||
|
||||
#### Preparing Model YAML
|
||||
|
||||
```yaml
|
||||
model: claude-2.1 # Model identifier
|
||||
# Model display name, can be set in en_US English and zh_Hans Chinese, zh_Hans will default to en_US if not set.
|
||||
# Alternatively, if the label is not set, use the model identifier content.
|
||||
label:
|
||||
en_US: claude-2.1
|
||||
model_type: llm # Model type, claude-2.1 is an LLM
|
||||
features: # Supported features, agent-thought for Agent reasoning, vision for image understanding
|
||||
- agent-thought
|
||||
model_properties: # Model properties
|
||||
mode: chat # LLM mode, complete for text completion model, chat for dialogue model
|
||||
context_size: 200000 # Maximum supported context size
|
||||
parameter_rules: # Model invocation parameter rules, only required for LLM
|
||||
- name: temperature # Invocation parameter variable name
|
||||
# Default preset with 5 variable content configuration templates: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
|
||||
# Directly set the template variable name in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||
# If additional configuration parameters are set, they will override the default configuration
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label: # Invocation parameter display name
|
||||
zh_Hans: Sampling quantity
|
||||
en_US: Top k
|
||||
type: int # Parameter type, supports float/int/string/boolean
|
||||
help: # Help information, describing the role of the parameter
|
||||
zh_Hans: Only sample from the top K options for each subsequent token.
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false # Whether required, can be left unset
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
default: 4096 # Default parameter value
|
||||
min: 1 # Minimum parameter value, only applicable for float/int
|
||||
max: 4096 # Maximum parameter value, only applicable for float/int
|
||||
pricing: # Pricing information
|
||||
input: '8.00' # Input price, i.e., Prompt price
|
||||
output: '24.00' # Output price, i.e., returned content price
|
||||
unit: '0.000001' # Pricing unit, i.e., the above prices are quoted per 1M tokens
|
||||
currency: USD # Currency
|
||||
```
|
||||
|
||||
It is recommended to prepare all model configurations before starting the implementation of the model code.
|
||||
|
||||
Similarly, you can also refer to the YAML configuration information for corresponding model types of other providers in the `model_providers` directory. The complete YAML rules can be found at: [Schema](schema.md#AIModel).
|
||||
|
||||
#### Implementing Model Invocation Code
|
||||
|
||||
Next, you need to create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
|
||||
|
||||
In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguageModel` (any name works), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
|
||||
|
||||
- LLM Invocation
|
||||
|
||||
Implement the core method for LLM invocation, which can support both streaming and synchronous returns.
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
- Pre-calculating Input Tokens
|
||||
|
||||
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Model Credential Verification
|
||||
|
||||
Similar to provider credential verification, this step involves verification for an individual model.
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Invocation Error Mapping Table
|
||||
|
||||
When there is an exception in model invocation, it needs to be mapped to the `InvokeError` type specified by Runtime. This facilitates Dify's ability to handle different errors with appropriate follow-up actions.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Invocation connection error
|
||||
- `InvokeServerUnavailableError` Invocation service provider unavailable
|
||||
- `InvokeRateLimitError` Invocation reached rate limit
|
||||
- `InvokeAuthorizationError` Invocation authorization failure
|
||||
- `InvokeBadRequestError` Invocation parameter error
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
For details on the interface methods, see: [Interfaces](interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
|
||||
|
||||
### Testing
|
||||
|
||||
To ensure the availability of integrated providers/models, each method written needs corresponding integration test code in the `tests` directory.
|
||||
|
||||
Continuing with `Anthropic` as an example:
|
||||
|
||||
Before writing test code, you need to first add the necessary credential environment variables for the test provider in `.env.example`, such as: `ANTHROPIC_API_KEY`.
|
||||
|
||||
Before running the tests, copy `.env.example` to `.env` and fill in the credential values, then execute them.
|
||||
|
||||
#### Writing Test Code
|
||||
|
||||
Create a `module` with the same name as the provider in the `tests` directory, here `anthropic`, and within it create `test_provider.py` along with test files for the corresponding model types, as shown below:
|
||||
|
||||
```shell
|
||||
.
|
||||
├── __init__.py
|
||||
├── anthropic
|
||||
│ ├── __init__.py
|
||||
│ ├── test_llm.py # LLM Testing
|
||||
│ └── test_provider.py # Provider Testing
|
||||
```
|
||||
|
||||
Write test code for all the various cases implemented above and submit the code after passing the tests.
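A hedged sketch of `tests/anthropic/test_provider.py`; the import paths follow a common layout and may differ in your repository, and the credential is read from the `ANTHROPIC_API_KEY` environment variable mentioned above.

```python
# Sketch only: adjust import paths to your repository layout.
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider


def test_validate_provider_credentials():
    provider = AnthropicProvider()

    # Obviously invalid credentials should be rejected.
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={"anthropic_api_key": "invalid-key"})

    # Valid credentials are taken from the environment (see .env.example).
    provider.validate_provider_credentials(
        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
    )
```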
|
||||
|
|
@ -1,208 +0,0 @@
|
|||
# Configuration Rules
|
||||
|
||||
- Provider rules are based on the [Provider](#Provider) entity.
|
||||
- Model rules are based on the [AIModelEntity](#AIModelEntity) entity.
|
||||
|
||||
> All entities mentioned below are based on `Pydantic BaseModel` and can be found in the `entities` module.
|
||||
|
||||
### Provider
|
||||
|
||||
- `provider` (string) Provider identifier, e.g., `openai`
|
||||
- `label` (object) Provider display name, i18n, with `en_US` English and `zh_Hans` Chinese language settings
|
||||
- `zh_Hans` (string) [optional] Chinese label name, if `zh_Hans` is not set, `en_US` will be used by default.
|
||||
- `en_US` (string) English label name
|
||||
- `description` (object) Provider description, i18n
|
||||
- `zh_Hans` (string) [optional] Chinese description
|
||||
- `en_US` (string) English description
|
||||
- `icon_small` (string) [optional] Small provider ICON, stored in the `_assets` directory under the corresponding provider implementation directory, with the same language strategy as `label`
|
||||
- `zh_Hans` (string) Chinese ICON
|
||||
- `en_US` (string) English ICON
|
||||
- `icon_large` (string) [optional] Large provider ICON, stored in the `_assets` directory under the corresponding provider implementation directory, with the same language strategy as `label`
|
||||
- `zh_Hans` (string) Chinese ICON
|
||||
- `en_US` (string) English ICON
|
||||
- `background` (string) [optional] Background color value, e.g., #FFFFFF, if empty, the default frontend color value will be displayed.
|
||||
- `help` (object) [optional] help information
|
||||
- `title` (object) help title, i18n
|
||||
- `zh_Hans` (string) [optional] Chinese title
|
||||
- `en_US` (string) English title
|
||||
- `url` (object) help link, i18n
|
||||
- `zh_Hans` (string) [optional] Chinese link
|
||||
- `en_US` (string) English link
|
||||
- `supported_model_types` (array\[[ModelType](#ModelType)\]) Supported model types
|
||||
- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) Configuration methods
|
||||
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) Provider credential specification
|
||||
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) Model credential specification
|
||||
|
||||
### AIModelEntity
|
||||
|
||||
- `model` (string) Model identifier, e.g., `gpt-3.5-turbo`
|
||||
- `label` (object) [optional] Model display name, i18n, with `en_US` English and `zh_Hans` Chinese language settings
|
||||
- `zh_Hans` (string) [optional] Chinese label name
|
||||
- `en_US` (string) English label name
|
||||
- `model_type` ([ModelType](#ModelType)) Model type
|
||||
- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] Supported feature list
|
||||
- `model_properties` (object) Model properties
|
||||
- `mode` ([LLMMode](#LLMMode)) Mode (available for model type `llm`)
|
||||
- `context_size` (int) Context size (available for model types `llm`, `text-embedding`)
|
||||
- `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`)
|
||||
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
|
||||
- `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`)
|
||||
- `default_voice` (string) Default voice, e.g., alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
|
||||
- `voices` (list) List of available voices (available for model type `tts`)
|
||||
- `mode` (string) Voice model (available for model type `tts`)
|
||||
- `name` (string) Voice model display name (available for model type `tts`)
|
||||
- `language` (string) Languages supported by the voice model (available for model type `tts`)
|
||||
- `word_limit` (int) Word limit per single conversion, paragraph-wise by default (available for model type `tts`)
|
||||
- `audio_type` (string) Supported audio file extension formats, e.g., mp3, wav (available for model type `tts`)
|
||||
- `max_workers` (int) Number of concurrent workers for text-to-audio conversion (available for model type `tts`)
|
||||
- `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
|
||||
- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] Model invocation parameter rules
|
||||
- `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information
|
||||
- `deprecated` (bool) Whether deprecated. If deprecated, the model will no longer be displayed in the list, but those already configured can continue to be used. Default False.
|
||||
|
||||
### ModelType
|
||||
|
||||
- `llm` Text generation model
|
||||
- `text-embedding` Text Embedding model
|
||||
- `rerank` Rerank model
|
||||
- `speech2text` Speech to text
|
||||
- `tts` Text to speech
|
||||
- `moderation` Moderation
|
||||
|
||||
### ConfigurateMethod
|
||||
|
||||
- `predefined-model` Predefined model
|
||||
|
||||
Indicates that users can use the predefined models under the provider by configuring the unified provider credentials.
|
||||
|
||||
- `customizable-model` Customizable model
|
||||
|
||||
Users need to add credential configuration for each model.
|
||||
|
||||
- `fetch-from-remote` Fetch from remote
|
||||
|
||||
Consistent with the `predefined-model` configuration method, only unified provider credentials need to be configured, and models are obtained from the provider through credential information.
|
||||
|
||||
### ModelFeature
|
||||
|
||||
- `agent-thought` Agent reasoning; generally models above 70B parameters have chain-of-thought capability.
|
||||
- `vision` Vision, i.e., image understanding.
|
||||
- `tool-call`
|
||||
- `multi-tool-call`
|
||||
- `stream-tool-call`
|
||||
|
||||
### FetchFrom
|
||||
|
||||
- `predefined-model` Predefined model
|
||||
- `fetch-from-remote` Remote model
|
||||
|
||||
### LLMMode
|
||||
|
||||
- `complete` Text completion
|
||||
- `chat` Dialogue
|
||||
|
||||
### ParameterRule
|
||||
|
||||
- `name` (string) Actual model invocation parameter name
|
||||
|
||||
- `use_template` (string) [optional] Using template
|
||||
|
||||
By default, 5 variable content configuration templates are preset:
|
||||
|
||||
- `temperature`
|
||||
- `top_p`
|
||||
- `frequency_penalty`
|
||||
- `presence_penalty`
|
||||
- `max_tokens`
|
||||
|
||||
In use_template, you can directly set the template variable name, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||
No need to set any parameters other than `name` and `use_template`. If additional configuration parameters are set, they will override the default configuration.
|
||||
Refer to `openai/llm/gpt-3.5-turbo.yaml`.
|
||||
|
||||
- `label` (object) [optional] Label, i18n
|
||||
|
||||
- `zh_Hans`(string) [optional] Chinese label name
|
||||
- `en_US` (string) English label name
|
||||
|
||||
- `type`(string) [optional] Parameter type
|
||||
|
||||
- `int` Integer
|
||||
- `float` Float
|
||||
- `string` String
|
||||
- `boolean` Boolean
|
||||
|
||||
- `help` (string) [optional] Help information
|
||||
|
||||
- `zh_Hans` (string) [optional] Chinese help information
|
||||
- `en_US` (string) English help information
|
||||
|
||||
- `required` (bool) Required, default False.
|
||||
|
||||
- `default`(int/float/string/bool) [optional] Default value
|
||||
|
||||
- `min`(int/float) [optional] Minimum value, applicable only to numeric types
|
||||
|
||||
- `max`(int/float) [optional] Maximum value, applicable only to numeric types
|
||||
|
||||
- `precision`(int) [optional] Precision, number of decimal places to keep, applicable only to numeric types
|
||||
|
||||
- `options` (array[string]) [optional] Dropdown option values, applicable only when `type` is `string`, if not set or null, option values are not restricted
|
||||
|
||||
### PriceConfig
|
||||
|
||||
- `input` (float) Input price, i.e., Prompt price
|
||||
- `output` (float) Output price, i.e., returned content price
|
||||
- `unit` (float) Pricing unit, i.e., the per-token multiplier applied to the quoted price; e.g., `0.000001` means the prices above are quoted per 1M tokens.
|
||||
- `currency` (string) Currency unit
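As a quick worked example of how these fields combine (assuming, as in the YAML examples above, that the quoted price is multiplied by `unit` per token): with `input: '8.00'` and `unit: '0.000001'`, a 1,000-token prompt costs 1000 × 8.00 × 0.000001 = 0.008 in the configured currency.

```python
from decimal import Decimal

# Worked example of the pricing fields above (values from the claude-2.1 YAML).
input_price = Decimal("8.00")    # quoted price per pricing unit of prompt tokens
unit = Decimal("0.000001")       # pricing unit: prices are quoted per 1M tokens
prompt_tokens = 1000

total_price = prompt_tokens * input_price * unit   # Decimal('0.00800000')
```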
|
||||
|
||||
### ProviderCredentialSchema
|
||||
|
||||
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form standard
|
||||
|
||||
### ModelCredentialSchema
|
||||
|
||||
- `model` (object) Model identifier, variable name defaults to `model`
|
||||
- `label` (object) Model form item display name
|
||||
- `en_US` (string) English
|
||||
- `zh_Hans`(string) [optional] Chinese
|
||||
- `placeholder` (object) Model prompt content
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans`(string) [optional] Chinese
|
||||
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form standard
|
||||
|
||||
### CredentialFormSchema
|
||||
|
||||
- `variable` (string) Form item variable name
|
||||
- `label` (object) Form item label name
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans` (string) [optional] Chinese
|
||||
- `type` ([FormType](#FormType)) Form item type
|
||||
- `required` (bool) Whether required
|
||||
- `default`(string) Default value
|
||||
- `options` (array\[[FormOption](#FormOption)\]) Specific property of form items of type `select` or `radio`, defining dropdown content
|
||||
- `placeholder`(object) Specific property of form items of type `text-input`, placeholder content
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans` (string) [optional] Chinese
|
||||
- `max_length` (int) Specific property of form items of type `text-input`, defining maximum input length, 0 for no limit.
|
||||
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Shown only when the values of other form items meet the specified conditions; always shown if empty.
|
||||
|
||||
### FormType
|
||||
|
||||
- `text-input` Text input component
|
||||
- `secret-input` Password input component
|
||||
- `select` Single-choice dropdown
|
||||
- `radio` Radio component
|
||||
- `switch` Switch component, only supports `true` and `false` values
|
||||
|
||||
### FormOption
|
||||
|
||||
- `label` (object) Label
|
||||
- `en_US`(string) English
|
||||
- `zh_Hans`(string) [optional] Chinese
|
||||
- `value` (string) Dropdown option value
|
||||
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Shown only when the values of other form items meet the specified conditions; always shown if empty.
|
||||
|
||||
### FormShowOnObject
|
||||
|
||||
- `variable` (string) Variable name of other form items
|
||||
- `value` (string) Variable value of other form items
|
||||
|
|
@ -1,304 +0,0 @@
|
|||
## 自定义预定义模型接入
|
||||
|
||||
### 介绍
|
||||
|
||||
供应商集成完成后,接下来为供应商下模型的接入,为了帮助理解整个接入过程,我们以`Xinference`为例,逐步完成一个完整的供应商接入。
|
||||
|
||||
需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。
|
||||
|
||||
而不同于预定义模型,自定义供应商接入时永远会拥有如下两个参数,不需要在供应商 yaml 中定义。
|
||||
|
||||

|
||||
|
||||
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`,Runtime 会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
|
||||
|
||||
### 编写供应商 yaml
|
||||
|
||||
我们首先要确定,接入的这个供应商支持哪些类型的模型。
|
||||
|
||||
当前支持模型类型如下:
|
||||
|
||||
- `llm` 文本生成模型
|
||||
- `text_embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `tts` 文字转语音
|
||||
- `moderation` 审查
|
||||
|
||||
`Xinference` 支持 `LLM`、`Text Embedding` 和 `Rerank` 三种模型类型,那么我们开始编写 `xinference.yaml`。
|
||||
|
||||
```yaml
|
||||
provider: xinference #确定供应商标识
|
||||
label: # 供应商展示名称,可设置 en_US 英文、zh_Hans 中文两种语言,zh_Hans 不设置将默认使用 en_US。
|
||||
en_US: Xorbits Inference
|
||||
icon_small: # 小图标,可以参考其他供应商的图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
|
||||
en_US: icon_s_en.svg
|
||||
icon_large: # 大图标
|
||||
en_US: icon_l_en.svg
|
||||
help: # 帮助
|
||||
title:
|
||||
en_US: How to deploy Xinference
|
||||
zh_Hans: 如何部署 Xinference
|
||||
url:
|
||||
en_US: https://github.com/xorbitsai/inference
|
||||
supported_model_types: # 支持的模型类型,Xinference 同时支持 LLM/Text Embedding/Rerank
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods: # 因为 Xinference 为本地部署的供应商,并且没有预定义模型,需要用什么模型需要根据 Xinference 的文档自己部署,所以这里只支持自定义模型
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
```
|
||||
|
||||
随后,我们需要思考在 Xinference 中定义一个模型需要哪些凭据
|
||||
|
||||
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
|
||||
|
||||
```yaml
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: model_type
|
||||
type: select
|
||||
label:
|
||||
en_US: Model type
|
||||
zh_Hans: 模型类型
|
||||
required: true
|
||||
options:
|
||||
- value: text-generation
|
||||
label:
|
||||
en_US: Language Model
|
||||
zh_Hans: 语言模型
|
||||
- value: embeddings
|
||||
label:
|
||||
en_US: Text Embedding
|
||||
- value: reranking
|
||||
label:
|
||||
en_US: Rerank
|
||||
```
|
||||
|
||||
- 每一个模型都有自己的名称`model_name`,因此需要在这里定义
|
||||
|
||||
```yaml
|
||||
- variable: model_name
|
||||
type: text-input
|
||||
label:
|
||||
en_US: Model name
|
||||
zh_Hans: 模型名称
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 填写模型名称
|
||||
en_US: Input model name
|
||||
```
|
||||
|
||||
- 填写 Xinference 本地部署的地址
|
||||
|
||||
```yaml
|
||||
- variable: server_url
|
||||
label:
|
||||
zh_Hans: 服务器 URL
|
||||
en_US: Server url
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
|
||||
en_US: Enter the url of your Xinference, for example https://example.com/xxx
|
||||
```
|
||||
|
||||
- 每个模型都有唯一的 model_uid,因此需要在这里定义
|
||||
|
||||
```yaml
|
||||
- variable: model_uid
|
||||
label:
|
||||
zh_Hans: 模型 UID
|
||||
en_US: Model uid
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 Model UID
|
||||
en_US: Enter the model uid
|
||||
```
|
||||
|
||||
现在,我们就完成了供应商的基础定义。
|
||||
|
||||
### 编写模型代码
|
||||
|
||||
然后我们以 `llm` 类型为例,在 `xinference/llm/` 目录下编写 `llm.py`。
|
||||
|
||||
在 `llm.py` 中创建一个 Xinference LLM 类,我们取名为 `XinferenceAILargeLanguageModel`(随意),继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下几个方法:
|
||||
|
||||
- LLM 调用
|
||||
|
||||
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
在实现时,需要注意使用两个函数来返回数据,分别处理同步返回和流式返回。这是因为 Python 会把包含 `yield` 关键字的函数识别为生成器函数,其返回类型固定为 `Generator`,所以同步和流式返回需要分别实现(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- 预计算输入 tokens
|
||||
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
如果不想直接返回 0,也可以使用 `self._get_num_tokens_by_gpt2(text: str)` 来估算预计算 tokens,并确保环境变量 `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` 设置为 `true`。该方法位于 `AIModel` 基类中,使用 GPT2 的 Tokenizer 进行计算,只能作为替代方案,结果并不完全准确。
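下面是一个基于该替代方法的假设性实现草图(消息内容的拼接方式仅作示意,并非固定要求):

```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    # 将消息内容拼接为纯文本,使用基类提供的 GPT2 Tokenizer 估算 token 数
    text = " ".join(str(message.content) for message in prompt_messages if message.content)
    return self._get_num_tokens_by_gpt2(text)
```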
|
||||
|
||||
- 模型凭据校验
|
||||
|
||||
与供应商凭据校验类似,这里针对单个模型进行校验。
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- 模型参数 Schema
|
||||
|
||||
与预定义模型不同,由于没有在 yaml 文件中定义模型支持哪些参数,因此我们需要动态生成模型参数的 Schema。
|
||||
|
||||
如 Xinference 支持`max_tokens` `temperature` `top_p` 这三个模型参数。
|
||||
|
||||
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例 A 模型支持`top_k`,B 模型不支持`top_k`,那么我们需要在这里动态生成模型参数的 Schema,如下所示:
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature', type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度', en_US='Temperature'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p', type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P', en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens', type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度', en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# if model is A, add top_k to rules
|
||||
if model == 'A':
|
||||
rules.append(
|
||||
ParameterRule(
|
||||
name='top_k', type=ParameterType.INT,
|
||||
use_template='top_k',
|
||||
min=1,
|
||||
default=50,
|
||||
label=I18nObject(
|
||||
zh_Hans='Top K', en_US='Top K'
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
"""
|
||||
some NOT IMPORTANT code here
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=model_type,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: ModelType.LLM,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
return entity
|
||||
```
|
||||
|
||||
- 调用异常错误映射表
|
||||
|
||||
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。
|
||||
|
|
|
@ -1,744 +0,0 @@
|
|||
# 接口方法
|
||||
|
||||
这里介绍供应商和各模型类型需要实现的接口方法和参数说明。
|
||||
|
||||
## 供应商
|
||||
|
||||
继承 `__base.model_provider.ModelProvider` 基类,实现以下接口:
|
||||
|
||||
```python
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
```
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
验证失败请抛出 `errors.validate.CredentialsValidateFailedError` 错误。
|
||||
|
||||
**注:预定义模型需完整实现该接口,自定义模型供应商只需要如下简单实现即可**
|
||||
|
||||
```python
|
||||
class XinferenceProvider(Provider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
```
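对于预定义模型供应商,一个常见做法是直接复用某个模型类型的凭据校验。下面是一个假设性的草图,示例中的模型名仅作示意,`get_model_instance` 以实际基类提供的方法为准:

```python
class AnthropicProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        try:
            # 复用 LLM 模型的凭据校验逻辑,claude-2.1 仅为示意的模型名
            model_instance = self.get_model_instance(ModelType.LLM)
            model_instance.validate_credentials(model="claude-2.1", credentials=credentials)
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
```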
|
||||
|
||||
## 模型
|
||||
|
||||
模型分为 5 种不同的模型类型,不同模型类型继承的基类不同,需要实现的方法也不同。
|
||||
|
||||
### 通用接口
|
||||
|
||||
所有模型均需要统一实现下面 2 个方法:
|
||||
|
||||
- 模型凭据校验
|
||||
|
||||
与供应商凭据校验类似,这里针对单个模型进行校验。
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
验证失败请抛出 `errors.validate.CredentialsValidateFailedError` 错误。
|
||||
|
||||
- 调用异常错误映射表
|
||||
|
||||
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
也可以直接抛出对应 Errors,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
可参考 OpenAI `_invoke_error_mapping`。
|
||||
|
||||
### LLM
|
||||
|
||||
继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下接口:
|
||||
|
||||
- LLM 调用
|
||||
|
||||
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) Prompt 列表
|
||||
|
||||
若模型为 `Completion` 类型,则列表只需要传入一个 [UserPromptMessage](#UserPromptMessage) 元素即可;
|
||||
|
||||
若模型为 `Chat` 类型,需要根据消息不同传入 [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) 元素列表
|
||||
|
||||
- `model_parameters` (object) 模型参数
|
||||
|
||||
模型参数由模型 YAML 配置的 `parameter_rules` 定义。
|
||||
|
||||
- `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] 工具列表,等同于 `function calling` 中的 `function`。
|
||||
|
||||
即传入 tool calling 的工具列表。
|
||||
|
||||
- `stop` (array[string]) [optional] 停止序列
|
||||
|
||||
模型返回将在停止序列定义的字符串之前停止输出。
|
||||
|
||||
- `stream` (bool) 是否流式输出,默认 True
|
||||
|
||||
流式输出返回 Generator\[[LLMResultChunk](#LLMResultChunk)\],非流式输出返回 [LLMResult](#LLMResult)。
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回
|
||||
|
||||
流式输出返回 Generator\[[LLMResultChunk](#LLMResultChunk)\],非流式输出返回 [LLMResult](#LLMResult)。
|
||||
|
||||
- 预计算输入 tokens
|
||||
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
参数说明见上述 `LLM 调用`。
|
||||
|
||||
该接口需要根据对应`model`选择合适的`tokenizer`进行计算,如果对应模型没有提供`tokenizer`,可以使用`AIModel`基类中的`_get_num_tokens_by_gpt2(text: str)`方法进行计算。
|
||||
|
||||
- 获取自定义模型规则 [可选]
|
||||
|
||||
```python
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
```
|
||||
|
||||
当供应商支持增加自定义 LLM 时,可实现此方法让自定义模型可获取模型规则,默认返回 None。
|
||||
|
||||
对于`OpenAI`供应商下的大部分微调模型,可以通过其微调模型名称获取到其基类模型,如`gpt-3.5-turbo-1106`,然后返回基类模型的预定义参数规则,参考[openai](https://github.com/langgenius/dify/blob/feat/model-runtime/api/core/model_runtime/model_providers/openai/llm/llm.py#L801)
|
||||
的具体实现
|
||||
|
||||
### TextEmbedding
|
||||
|
||||
继承 `__base.text_embedding_model.TextEmbeddingModel` 基类,实现以下接口:
|
||||
|
||||
- Embedding 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `texts` (array[string]) 文本列表,可批量处理
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
[TextEmbeddingResult](#TextEmbeddingResult) 实体。
|
||||
|
||||
- 预计算 tokens
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
参数说明见上述 `Embedding 调用`。
|
||||
|
||||
同上述`LargeLanguageModel`,该接口需要根据对应`model`选择合适的`tokenizer`进行计算,如果对应模型没有提供`tokenizer`,可以使用`AIModel`基类中的`_get_num_tokens_by_gpt2(text: str)`方法进行计算。
|
||||
|
||||
### Rerank
|
||||
|
||||
继承 `__base.rerank_model.RerankModel` 基类,实现以下接口:
|
||||
|
||||
- rerank 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `query` (string) 查询请求内容
|
||||
|
||||
- `docs` (array[string]) 需要重排的分段列表
|
||||
|
||||
- `score_threshold` (float) [optional] Score 阈值
|
||||
|
||||
- `top_n` (int) [optional] 取前 n 个分段
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
[RerankResult](#RerankResult) 实体。
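下面是一个组装返回结果的假设性草图,可放在 `_invoke` 的末尾(`scores` 表示上游打分接口返回的分数列表,仅作示意):

```python
rerank_documents = []
for index, (doc, score) in enumerate(zip(docs, scores)):
    # 低于阈值的分段直接过滤
    if score_threshold is not None and score < score_threshold:
        continue
    rerank_documents.append(RerankDocument(index=index, text=doc, score=score))

# 按分数从高到低排序,并截取前 top_n 个
rerank_documents.sort(key=lambda d: d.score, reverse=True)
if top_n is not None:
    rerank_documents = rerank_documents[:top_n]

return RerankResult(model=model, docs=rerank_documents)
```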
|
||||
|
||||
### Speech2text
|
||||
|
||||
继承 `__base.speech2text_model.Speech2TextModel` 基类,实现以下接口:
|
||||
|
||||
- Invoke 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke speech-to-text model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `file` (File) 文件流
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
语音转换后的字符串。
|
||||
|
||||
### Text2speech
|
||||
|
||||
继承 `__base.text2speech_model.Text2SpeechModel` 基类,实现以下接口:
|
||||
|
||||
- Invoke 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
|
||||
"""
|
||||
Invoke text-to-speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: translated audio file
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `content_text` (string) 需要转换的文本内容
|
||||
|
||||
- `streaming` (bool) 是否进行流式输出
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
文本转换后的语音流。
|
||||
|
||||
### Moderation
|
||||
|
||||
继承 `__base.moderation_model.ModerationModel` 基类,实现以下接口:
|
||||
|
||||
- Invoke 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `text` (string) 文本内容
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
False 代表传入的文本安全,True 则反之。
|
||||
|
||||
## 实体
|
||||
|
||||
### PromptMessageRole
|
||||
|
||||
消息角色
|
||||
|
||||
```python
|
||||
class PromptMessageRole(Enum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
```
|
||||
|
||||
### PromptMessageContentType
|
||||
|
||||
消息内容类型,分为纯文本和图片。
|
||||
|
||||
```python
|
||||
class PromptMessageContentType(Enum):
|
||||
"""
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
TEXT = 'text'
|
||||
IMAGE = 'image'
|
||||
```
|
||||
|
||||
### PromptMessageContent
|
||||
|
||||
消息内容基类,仅作为参数声明用,不可初始化。
|
||||
|
||||
```python
|
||||
class PromptMessageContent(BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType
|
||||
data: str # 内容数据
|
||||
```
|
||||
|
||||
当前支持文本和图片两种类型,可支持同时传入文本和多图。
|
||||
|
||||
需要分别初始化 `TextPromptMessageContent` 和 `ImagePromptMessageContent` 传入。
|
||||
|
||||
### TextPromptMessageContent
|
||||
|
||||
```python
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
||||
```
|
||||
|
||||
若传入图文,其中文字需要构造此实体作为 `content` 列表中的一部分。
|
||||
|
||||
### ImagePromptMessageContent
|
||||
|
||||
```python
|
||||
class ImagePromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for image prompt message content.
|
||||
"""
|
||||
class DETAIL(Enum):
|
||||
LOW = 'low'
|
||||
HIGH = 'high'
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW # 分辨率
|
||||
```
|
||||
|
||||
若传入图文,其中图片需要构造此实体作为 `content` 列表中的一部分
|
||||
|
||||
`data` 可以为 `url` 或者图片经 `base64` 编码后的字符串。
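结合上面两个实体,可以像下面这样构造图文混合的内容列表(URL 仅作示意),随后作为消息的 `content` 传入:

```python
content = [
    TextPromptMessageContent(data="请描述这张图片"),
    ImagePromptMessageContent(
        data="https://example.com/example.png",  # 也可传入 base64 编码后的字符串
        detail=ImagePromptMessageContent.DETAIL.HIGH,
    ),
]
```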
|
||||
|
||||
### PromptMessage
|
||||
|
||||
所有 Role 消息体的基类,仅作为参数声明用,不可初始化。
|
||||
|
||||
```python
|
||||
class PromptMessage(BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
role: PromptMessageRole # 消息角色
|
||||
content: Optional[str | list[PromptMessageContent]] = None # 支持两种类型,字符串和内容列表,内容列表是为了满足多模态的需要,可详见 PromptMessageContent 说明。
|
||||
name: Optional[str] = None # 名称,可选。
|
||||
```
|
||||
|
||||
### UserPromptMessage
|
||||
|
||||
UserMessage 消息体,代表用户消息。
|
||||
|
||||
```python
|
||||
class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for user prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.USER
|
||||
```
|
||||
|
||||
### AssistantPromptMessage
|
||||
|
||||
代表模型返回消息,通常用于 `few-shots` 或聊天历史传入。
|
||||
|
||||
```python
|
||||
class AssistantPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for assistant prompt message.
|
||||
"""
|
||||
class ToolCall(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call.
|
||||
"""
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call function.
|
||||
"""
|
||||
name: str # 工具名称
|
||||
arguments: str # 工具参数
|
||||
|
||||
id: str # 工具 ID,仅在 OpenAI tool call 生效,为工具调用的唯一 ID,同一个工具可以调用多次
|
||||
type: str # 默认 function
|
||||
function: ToolCallFunction # 工具调用信息
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = [] # 模型回复的工具调用结果(仅当传入 tools,并且模型认为需要调用工具时返回)
|
||||
```
|
||||
|
||||
其中 `tool_calls` 为调用模型传入 `tools` 后,由模型返回的 `tool call` 列表。
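一个假设性的构造示例(调用 ID、工具名与参数仅作示意):

```python
assistant_message = AssistantPromptMessage(
    content="",
    tool_calls=[
        AssistantPromptMessage.ToolCall(
            id="call_abc123",
            type="function",
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name="get_weather",
                arguments='{"city": "Beijing"}',
            ),
        )
    ],
)
```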
|
||||
|
||||
### SystemPromptMessage
|
||||
|
||||
代表系统消息,通常用于设定给模型的系统指令。
|
||||
|
||||
```python
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for system prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.SYSTEM
|
||||
```
|
||||
|
||||
### ToolPromptMessage
|
||||
|
||||
代表工具消息,用于工具执行后将结果交给模型进行下一步计划。
|
||||
|
||||
```python
|
||||
class ToolPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for tool prompt message.
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.TOOL
|
||||
tool_call_id: str # 工具调用 ID,若不支持 OpenAI tool call,也可传入工具名称
|
||||
```
|
||||
|
||||
基类的 `content` 传入工具执行结果。
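承接上文的工具调用示例,一个假设性的结果回传示例:

```python
tool_message = ToolPromptMessage(
    content='{"temperature": "25°C"}',  # 工具执行结果放入基类的 content
    tool_call_id="call_abc123",  # 对应模型返回的工具调用 ID
)
```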
|
||||
|
||||
### PromptMessageTool
|
||||
|
||||
```python
|
||||
class PromptMessageTool(BaseModel):
|
||||
"""
|
||||
Model class for prompt message tool.
|
||||
"""
|
||||
name: str # 工具名称
|
||||
description: str # 工具描述
|
||||
parameters: dict # 工具参数 dict
|
||||
```
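一个假设性的工具定义示例,`parameters` 通常为 JSON Schema 形式(字段内容仅作示意):

```python
weather_tool = PromptMessageTool(
    name="get_weather",
    description="查询指定城市的实时天气",
    parameters={
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "城市名称"},
        },
        "required": ["city"],
    },
)
```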
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### LLMResult
|
||||
|
||||
```python
|
||||
class LLMResult(BaseModel):
|
||||
"""
|
||||
Model class for llm result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
prompt_messages: list[PromptMessage] # prompt 消息列表
|
||||
message: AssistantPromptMessage # 回复消息
|
||||
usage: LLMUsage # 使用的 tokens 及费用信息
|
||||
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
|
||||
```
|
||||
|
||||
### LLMResultChunkDelta
|
||||
|
||||
流式返回中每个迭代内部 `delta` 实体
|
||||
|
||||
```python
|
||||
class LLMResultChunkDelta(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk delta.
|
||||
"""
|
||||
index: int # 序号
|
||||
message: AssistantPromptMessage # 回复消息
|
||||
usage: Optional[LLMUsage] = None # 使用的 tokens 及费用信息,仅最后一条返回
|
||||
finish_reason: Optional[str] = None # 结束原因,仅最后一条返回
|
||||
```
|
||||
|
||||
### LLMResultChunk
|
||||
|
||||
流式返回中每个迭代实体
|
||||
|
||||
```python
|
||||
class LLMResultChunk(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
prompt_messages: list[PromptMessage] # prompt 消息列表
|
||||
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
|
||||
delta: LLMResultChunkDelta # 每个迭代存在变化的内容
|
||||
```
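下面是一个产出流式分块的假设性草图(`upstream_chunks` 表示上游返回的文本片段流,仅作示意):

```python
def _handle_stream_response(self, model: str, prompt_messages: list[PromptMessage],
                            upstream_chunks) -> Generator:
    for index, text in enumerate(upstream_chunks):
        yield LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=index,
                message=AssistantPromptMessage(content=text),
            ),
        )
```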
|
||||
|
||||
### LLMUsage
|
||||
|
||||
```python
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
Model class for llm usage.
|
||||
"""
|
||||
prompt_tokens: int # prompt 使用 tokens
|
||||
prompt_unit_price: Decimal # prompt 单价
|
||||
prompt_price_unit: Decimal # prompt 价格单位,即单价基于多少 tokens
|
||||
prompt_price: Decimal # prompt 费用
|
||||
completion_tokens: int # 回复使用 tokens
|
||||
completion_unit_price: Decimal # 回复单价
|
||||
completion_price_unit: Decimal # 回复价格单位,即单价基于多少 tokens
|
||||
completion_price: Decimal # 回复费用
|
||||
total_tokens: int # 总使用 token 数
|
||||
total_price: Decimal # 总费用
|
||||
currency: str # 货币单位
|
||||
latency: float # 请求耗时 (s)
|
||||
```
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### TextEmbeddingResult
|
||||
|
||||
```python
|
||||
class TextEmbeddingResult(BaseModel):
|
||||
"""
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
embeddings: list[list[float]] # embedding 向量列表,对应传入的 texts 列表
|
||||
usage: EmbeddingUsage # 使用信息
|
||||
```
|
||||
|
||||
### EmbeddingUsage
|
||||
|
||||
```python
|
||||
class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
Model class for embedding usage.
|
||||
"""
|
||||
tokens: int # 使用 token 数
|
||||
total_tokens: int # 总使用 token 数
|
||||
unit_price: Decimal # 单价
|
||||
price_unit: Decimal # 价格单位,即单价基于多少 tokens
|
||||
total_price: Decimal # 总费用
|
||||
currency: str # 货币单位
|
||||
latency: float # 请求耗时 (s)
|
||||
```
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### RerankResult
|
||||
|
||||
```python
|
||||
class RerankResult(BaseModel):
|
||||
"""
|
||||
Model class for rerank result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
docs: list[RerankDocument] # 重排后的分段列表
|
||||
```
|
||||
|
||||
### RerankDocument
|
||||
|
||||
```python
|
||||
class RerankDocument(BaseModel):
|
||||
"""
|
||||
Model class for rerank document.
|
||||
"""
|
||||
index: int # 原序号
|
||||
text: str # 分段文本内容
|
||||
score: float # 分数
|
||||
```
|
||||
|
|
@ -1,172 +0,0 @@
|
|||
## 预定义模型接入
|
||||
|
||||
供应商集成完成后,接下来为供应商下模型的接入。
|
||||
|
||||
我们首先需要确定接入模型的类型,并在对应供应商的目录下创建对应模型类型的 `module`。
|
||||
|
||||
当前支持模型类型如下:
|
||||
|
||||
- `llm` 文本生成模型
|
||||
- `text_embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `tts` 文字转语音
|
||||
- `moderation` 审查
|
||||
|
||||
依旧以 `Anthropic` 为例,`Anthropic` 仅支持 LLM,因此在 `model_providers.anthropic` 创建一个 `llm` 为名称的 `module`。
|
||||
|
||||
对于预定义的模型,我们首先需要在 `llm` `module` 下创建以模型名为文件名称的 YAML 文件,如:`claude-2.1.yaml`。
|
||||
|
||||
### 准备模型 YAML
|
||||
|
||||
```yaml
|
||||
model: claude-2.1 # 模型标识
|
||||
# 模型展示名称,可设置 en_US 英文、zh_Hans 中文两种语言,zh_Hans 不设置将默认使用 en_US。
|
||||
# 也可不设置 label,则使用 model 标识内容。
|
||||
label:
|
||||
en_US: claude-2.1
|
||||
model_type: llm # 模型类型,claude-2.1 为 LLM
|
||||
features: # 支持功能,agent-thought 为支持 Agent 推理,vision 为支持图片理解
|
||||
- agent-thought
|
||||
model_properties: # 模型属性
|
||||
mode: chat # LLM 模式,complete 文本补全模型,chat 对话模型
|
||||
context_size: 200000 # 支持最大上下文大小
|
||||
parameter_rules: # 模型调用参数规则,仅 LLM 需要提供
|
||||
- name: temperature # 调用参数变量名
|
||||
# 默认预置了 5 种变量内容配置模板,temperature/top_p/max_tokens/presence_penalty/frequency_penalty
|
||||
# 可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置
|
||||
# 若设置了额外的配置参数,将覆盖默认配置
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label: # 调用参数展示名称
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int # 参数类型,支持 float/int/string/boolean
|
||||
help: # 帮助信息,描述参数作用
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false # 是否必填,可不设置
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
default: 4096 # 参数默认值
|
||||
min: 1 # 参数最小值,仅 float/int 可用
|
||||
max: 4096 # 参数最大值,仅 float/int 可用
|
||||
pricing: # 价格信息
|
||||
input: '8.00' # 输入单价,即 Prompt 单价
|
||||
output: '24.00' # 输出单价,即返回内容单价
|
||||
unit: '0.000001' # 价格单位,即上述价格为每 1M tokens 的单价
|
||||
currency: USD # 价格货币
|
||||
```
|
||||
|
||||
建议将所有模型配置都准备完毕后再开始模型代码的实现。
|
||||
|
||||
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。
|
||||
|
||||
### 实现模型调用代码
|
||||
|
||||
接下来需要在 `llm` `module` 下创建一个同名的 python 文件 `llm.py` 来编写代码实现。
|
||||
|
||||
在 `llm.py` 中创建一个 Anthropic LLM 类,我们取名为 `AnthropicLargeLanguageModel`(随意),继承 `__base.large_language_model.LargeLanguageModel` 基类,实现以下几个方法:
|
||||
|
||||
- LLM 调用
|
||||
|
||||
实现 LLM 调用的核心方法,可同时支持流式和同步返回。
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
在实现时,需要注意使用两个函数来返回数据,分别处理同步返回和流式返回。这是因为 Python 会把包含 `yield` 关键字的函数识别为生成器函数,其返回类型固定为 `Generator`,所以同步和流式返回需要分别实现(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- 预计算输入 tokens
|
||||
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- 模型凭据校验
|
||||
|
||||
与供应商凭据校验类似,这里针对单个模型进行校验。
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- 调用异常错误映射表
|
||||
|
||||
当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` 调用连接错误
|
||||
- `InvokeServerUnavailableError ` 调用服务方不可用
|
||||
- `InvokeRateLimitError ` 调用达到限额
|
||||
- `InvokeAuthorizationError` 调用鉴权失败
|
||||
- `InvokeBadRequestError ` 调用传参有误
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
## 增加新供应商
|
||||
|
||||
供应商支持三种模型配置方式:
|
||||
|
||||
- `predefined-model ` 预定义模型
|
||||
|
||||
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
|
||||
|
||||
- `customizable-model` 自定义模型
|
||||
|
||||
用户需要新增每个模型的凭据配置,如 Xinference,它同时支持 LLM 和 Text Embedding,但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
|
||||
|
||||
- `fetch-from-remote` 从远程获取
|
||||
|
||||
与 `predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
|
||||
|
||||
如 OpenAI,我们可以基于 gpt-3.5-turbo 来 Fine Tune 多个模型,而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让 DifyRuntime 获取到开发者所有的微调模型并接入 Dify。
|
||||
|
||||
这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model` 或 `predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。
|
||||
|
||||
## 开始
|
||||
|
||||
### 介绍
|
||||
|
||||
#### 名词解释
|
||||
|
||||
- `module`: 一个`module`即为一个 Python Package,或者通俗一点,称为一个文件夹,里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
|
||||
|
||||
#### 步骤
|
||||
|
||||
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
|
||||
|
||||
- 创建供应商 yaml 文件,根据[ProviderSchema](./schema.md#provider)编写
|
||||
- 创建供应商代码,实现一个`class`。
|
||||
- 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm`或`text_embedding`。
|
||||
- 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`。
|
||||
- 如果有预定义模型,根据模型名称创建同名的 yaml 文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。
|
||||
- 编写测试代码,确保功能可用。
|
||||
|
||||
### 开始吧
|
||||
|
||||
增加一个新的供应商需要先确定供应商的英文标识,如 `anthropic`,使用该标识在 `model_providers` 创建以此为名称的 `module`。
|
||||
|
||||
在此 `module` 下,我们需要先准备供应商的 YAML 配置。
|
||||
|
||||
#### 准备供应商 YAML
|
||||
|
||||
此处以 `Anthropic` 为例,预设了供应商基础信息、支持的模型类型、配置方式、凭据规则。
|
||||
|
||||
```YAML
|
||||
provider: anthropic # 供应商标识
|
||||
label: # 供应商展示名称,可设置 en_US 英文、zh_Hans 中文两种语言,zh_Hans 不设置将默认使用 en_US。
|
||||
en_US: Anthropic
|
||||
icon_small: # 供应商小图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
|
||||
en_US: icon_s_en.png
|
||||
icon_large: # 供应商大图标,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label
|
||||
en_US: icon_l_en.png
|
||||
supported_model_types: # 支持的模型类型,Anthropic 仅支持 LLM
|
||||
- llm
|
||||
configurate_methods: # 支持的配置方式,Anthropic 仅支持预定义模型
|
||||
- predefined-model
|
||||
provider_credential_schema: # 供应商凭据规则,由于 Anthropic 仅支持预定义模型,则需要定义统一供应商凭据规则
|
||||
credential_form_schemas: # 凭据表单项列表
|
||||
- variable: anthropic_api_key # 凭据参数变量名
|
||||
label: # 展示名称
|
||||
en_US: API Key
|
||||
type: secret-input # 表单类型,此处 secret-input 代表加密信息输入框,编辑时只展示屏蔽后的信息。
|
||||
required: true # 是否必填
|
||||
placeholder: # PlaceHolder 信息
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: anthropic_api_url
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input # 表单类型,此处 text-input 代表文本输入框
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API URL
|
||||
en_US: Enter your API URL
|
||||
```
|
||||
|
||||
如果接入的供应商提供自定义模型,比如`OpenAI`提供微调模型,那么我们就需要添加[`model_credential_schema`](./schema.md#modelcredentialschema),以`OpenAI`为例:
|
||||
|
||||
```yaml
|
||||
model_credential_schema:
|
||||
model: # 微调模型名称
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: openai_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: openai_organization
|
||||
label:
|
||||
zh_Hans: 组织 ID
|
||||
en_US: Organization
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的组织 ID
|
||||
en_US: Enter your Organization ID
|
||||
- variable: openai_api_base
|
||||
label:
|
||||
zh_Hans: API Base
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base
|
||||
en_US: Enter your API Base
|
||||
```
|
||||
|
||||
也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。
|
||||
|
||||
#### 实现供应商代码
|
||||
|
||||
我们需要在`model_providers`下创建一个同名的 python 文件,如`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`。
|
||||
|
||||
##### 自定义模型供应商
|
||||
|
||||
当供应商为 Xinference 等自定义模型供应商时,可跳过该步骤,仅创建一个空的`XinferenceProvider`类,并实现一个空的`validate_provider_credentials`方法即可。该方法并不会被实际调用,仅用于满足抽象基类的要求,避免类无法实例化。
|
||||
|
||||
```python
|
||||
class XinferenceProvider(Provider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
```
|
||||
|
||||
##### 预定义模型供应商
|
||||
|
||||
供应商需要继承 `__base.model_provider.ModelProvider` 基类,实现 `validate_provider_credentials` 供应商统一凭据校验方法即可,可参考 [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py)。
|
||||
|
||||
```python
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
You can choose any validate_credentials method of model type or implement validate method by yourself,
|
||||
such as: get model list api
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
```
|
||||
|
||||
当然也可以先预留 `validate_provider_credentials` 实现,在模型凭据校验方法实现后直接复用。
|
||||
|
||||
#### 增加模型
|
||||
|
||||
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
|
||||
|
||||
对于预定义模型,我们可以通过简单定义一个 yaml,并通过实现调用代码来接入。
|
||||
|
||||
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
|
||||
|
||||
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
### 测试
|
||||
|
||||
为了保证接入供应商/模型的可用性,编写后的每个方法均需要在 `tests` 目录中编写对应的集成测试代码。
|
||||
|
||||
依旧以 `Anthropic` 为例。
|
||||
|
||||
在编写测试代码前,需要先在 `.env.example` 新增测试供应商所需要的凭据环境变量,如:`ANTHROPIC_API_KEY`。
|
||||
|
||||
在执行前需要将 `.env.example` 复制为 `.env` 再执行。
|
||||
|
||||
#### 编写测试代码
|
||||
|
||||
在 `tests` 目录下创建供应商同名的 `module`: `anthropic`,继续在此模块中创建 `test_provider.py` 以及对应模型类型的 test py 文件,如下所示:
|
||||
|
||||
```shell
|
||||
.
|
||||
├── __init__.py
|
||||
├── anthropic
|
||||
│ ├── __init__.py
|
||||
│ ├── test_llm.py # LLM 测试
|
||||
│ └── test_provider.py # 供应商测试
|
||||
```
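其中 `test_provider.py` 的一个假设性草图如下(导入路径与断言内容仅作示意,需根据实际工程结构调整):

```python
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider


def test_validate_provider_credentials():
    provider = AnthropicProvider()

    # 非法凭据应当抛出校验异常
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={"anthropic_api_key": "invalid"})

    # 合法凭据(从环境变量读取)应当通过校验
    provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")})
```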
|
||||
|
||||
针对上面实现的代码的各种情况进行测试代码编写,并测试通过后提交代码。
|
||||
|
|
@ -1,209 +0,0 @@
|
|||
# 配置规则
|
||||
|
||||
- 供应商规则基于 [Provider](#Provider) 实体。
|
||||
|
||||
- 模型规则基于 [AIModelEntity](#AIModelEntity) 实体。
|
||||
|
||||
> 以下所有实体均基于 `Pydantic BaseModel`,可在 `entities` 模块中找到对应实体。
|
||||
|
||||
### Provider
|
||||
|
||||
- `provider` (string) 供应商标识,如:`openai`
|
||||
- `label` (object) 供应商展示名称,i18n,可设置 `en_US` 英文、`zh_Hans` 中文两种语言
|
||||
- `zh_Hans ` (string) [optional] 中文标签名,`zh_Hans` 不设置将默认使用 `en_US`。
|
||||
- `en_US` (string) 英文标签名
|
||||
- `description` (object) [optional] 供应商描述,i18n
|
||||
- `zh_Hans` (string) [optional] 中文描述
|
||||
- `en_US` (string) 英文描述
|
||||
- `icon_small` (string) [optional] 供应商小 ICON,存储在对应供应商实现目录下的 `_assets` 目录,中英文策略同 `label`
|
||||
- `zh_Hans` (string) [optional] 中文 ICON
|
||||
- `en_US` (string) 英文 ICON
|
||||
- `icon_large` (string) [optional] 供应商大 ICON,存储在对应供应商实现目录下的 \_assets 目录,中英文策略同 label
|
||||
- `zh_Hans `(string) [optional] 中文 ICON
|
||||
- `en_US` (string) 英文 ICON
|
||||
- `background` (string) [optional] 背景颜色色值,例:#FFFFFF,为空则展示前端默认色值。
|
||||
- `help` (object) [optional] 帮助信息
|
||||
- `title` (object) 帮助标题,i18n
|
||||
- `zh_Hans` (string) [optional] 中文标题
|
||||
- `en_US` (string) 英文标题
|
||||
- `url` (object) 帮助链接,i18n
|
||||
- `zh_Hans` (string) [optional] 中文链接
|
||||
- `en_US` (string) 英文链接
|
||||
- `supported_model_types` (array\[[ModelType](#ModelType)\]) 支持的模型类型
|
||||
- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) 配置方式
|
||||
- `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) 供应商凭据规格
|
||||
- `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) 模型凭据规格
|
||||
|
||||
### AIModelEntity
|
||||
|
||||
- `model` (string) 模型标识,如:`gpt-3.5-turbo`
|
||||
- `label` (object) [optional] 模型展示名称,i18n,可设置 `en_US` 英文、`zh_Hans` 中文两种语言
|
||||
- `zh_Hans `(string) [optional] 中文标签名
|
||||
- `en_US` (string) 英文标签名
|
||||
- `model_type` ([ModelType](#ModelType)) 模型类型
|
||||
- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] 支持功能列表
|
||||
- `model_properties` (object) 模型属性
|
||||
- `mode` ([LLMMode](#LLMMode)) 模式 (模型类型 `llm` 可用)
|
||||
- `context_size` (int) 上下文大小 (模型类型 `llm` `text-embedding` 可用)
|
||||
- `max_chunks` (int) 最大分块数量 (模型类型 `text-embedding ` `moderation` 可用)
|
||||
- `file_upload_limit` (int) 文件最大上传限制,单位:MB。(模型类型 `speech2text` 可用)
|
||||
- `supported_file_extensions` (string) 支持文件扩展格式,如:mp3,mp4(模型类型 `speech2text` 可用)
|
||||
- `default_voice` (string) 缺省音色,必选:alloy,echo,fable,onyx,nova,shimmer(模型类型 `tts` 可用)
|
||||
- `voices` (list) 可选音色列表。
|
||||
- `mode` (string) 音色模型。(模型类型 `tts` 可用)
|
||||
- `name` (string) 音色模型显示名称。(模型类型 `tts` 可用)
|
||||
- `language` (string) 音色模型支持语言。(模型类型 `tts` 可用)
|
||||
- `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用)
|
||||
- `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用)
|
||||
- `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用)
|
||||
- `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用)
|
||||
- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] 模型调用参数规则
|
||||
- `pricing` ([PriceConfig](#PriceConfig)) [optional] 价格信息
|
||||
- `deprecated` (bool) 是否废弃。若废弃,模型列表将不再展示,但已经配置的可以继续使用,默认 False。
|
||||
|
||||
### ModelType
|
||||
|
||||
- `llm` 文本生成模型
|
||||
- `text-embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `tts` 文字转语音
|
||||
- `moderation` 审查
|
||||
|
||||
### ConfigurateMethod
|
||||
|
||||
- `predefined-model ` 预定义模型
|
||||
|
||||
表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。
|
||||
|
||||
- `customizable-model` 自定义模型
|
||||
|
||||
用户需要新增每个模型的凭据配置。
|
||||
|
||||
- `fetch-from-remote` 从远程获取
|
||||
|
||||
与 `predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
|
||||
|
||||
### ModelFeature
|
||||
|
||||
- `agent-thought` Agent 推理,一般超过 70B 有思维链能力。
|
||||
- `vision` 视觉,即:图像理解。
|
||||
- `tool-call` 工具调用
|
||||
- `multi-tool-call` 多工具调用
|
||||
- `stream-tool-call` 流式工具调用
|
||||
|
||||
### FetchFrom
|
||||
|
||||
- `predefined-model` 预定义模型
|
||||
- `fetch-from-remote` 远程模型
|
||||
|
||||
### LLMMode
|
||||
|
||||
- `completion` 文本补全
|
||||
- `chat` 对话
|
||||
|
||||
### ParameterRule
|
||||
|
||||
- `name` (string) 调用模型实际参数名
|
||||
|
||||
- `use_template` (string) [optional] 使用模板
|
||||
|
||||
默认预置了 5 种变量内容配置模板:
|
||||
|
||||
- `temperature`
|
||||
- `top_p`
|
||||
- `frequency_penalty`
|
||||
- `presence_penalty`
|
||||
- `max_tokens`
|
||||
|
||||
可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置
|
||||
不用设置除 `name` 和 `use_template` 之外的所有参数,若设置了额外的配置参数,将覆盖默认配置。
|
||||
可参考 `openai/llm/gpt-3.5-turbo.yaml`。
|
||||
|
||||
- `label` (object) [optional] 标签,i18n
|
||||
|
||||
- `zh_Hans`(string) [optional] 中文标签名
|
||||
- `en_US` (string) 英文标签名
|
||||
|
||||
- `type`(string) [optional] 参数类型
|
||||
|
||||
- `int` 整数
|
||||
- `float` 浮点数
|
||||
- `string` 字符串
|
||||
- `boolean` 布尔型
|
||||
|
||||
- `help` (string) [optional] 帮助信息
|
||||
|
||||
- `zh_Hans` (string) [optional] 中文帮助信息
|
||||
- `en_US` (string) 英文帮助信息
|
||||
|
||||
- `required` (bool) 是否必填,默认 False。
|
||||
|
||||
- `default`(int/float/string/bool) [optional] 默认值
|
||||
|
||||
- `min`(int/float) [optional] 最小值,仅数字类型适用
|
||||
|
||||
- `max`(int/float) [optional] 最大值,仅数字类型适用
|
||||
|
||||
- `precision`(int) [optional] 精度,保留小数位数,仅数字类型适用
|
||||
|
||||
- `options` (array[string]) [optional] 下拉选项值,仅当 `type` 为 `string` 时适用,若不设置或为 null 则不限制选项值
|
||||
|
||||
### PriceConfig
|
||||
|
||||
- `input` (float) 输入单价,即 Prompt 单价
|
||||
- `output` (float) 输出单价,即返回内容单价
|
||||
- `unit` (float) 价格单位,如以 1M tokens 计价,则单价对应的单位 token 数为 `0.000001`
|
||||
- `currency` (string) 货币单位
|
||||
|
||||
### ProviderCredentialSchema
|
||||
|
||||
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) 凭据表单规范
|
||||
|
||||
### ModelCredentialSchema
|
||||
|
||||
- `model` (object) 模型标识,变量名默认 `model`
|
||||
- `label` (object) 模型表单项展示名称
|
||||
- `en_US` (string) 英文
|
||||
- `zh_Hans`(string) [optional] 中文
|
||||
- `placeholder` (object) 模型提示内容
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans`(string) [optional] 中文
|
||||
- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) 凭据表单规范
|
||||
|
||||
### CredentialFormSchema
|
||||
|
||||
- `variable` (string) 表单项变量名
|
||||
- `label` (object) 表单项标签名
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans` (string) [optional] 中文
|
||||
- `type` ([FormType](#FormType)) 表单项类型
|
||||
- `required` (bool) 是否必填
|
||||
- `default`(string) 默认值
|
||||
- `options` (array\[[FormOption](#FormOption)\]) 表单项为 `select` 或 `radio` 专有属性,定义下拉内容
|
||||
- `placeholder`(object) 表单项为 `text-input `专有属性,表单项 PlaceHolder
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans` (string) [optional] 中文
|
||||
- `max_length` (int) 表单项为`text-input`专有属性,定义输入最大长度,0 为不限制。
|
||||
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) 当其他表单项值符合条件时显示,为空则始终显示。
|
||||
|
||||
### FormType
|
||||
|
||||
- `text-input` 文本输入组件
|
||||
- `secret-input` 密码输入组件
|
||||
- `select` 单选下拉
|
||||
- `radio` Radio 组件
|
||||
- `switch` 开关组件,仅支持 `true` 和 `false`
|
||||
|
||||
### FormOption
|
||||
|
||||
- `label` (object) 标签
|
||||
- `en_US`(string) 英文
|
||||
- `zh_Hans`(string) [optional] 中文
|
||||
- `value` (string) 下拉选项值
|
||||
- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) 当其他表单项值符合条件时显示,为空则始终显示。
|
||||
|
||||
### FormShowOnObject
|
||||
|
||||
- `variable` (string) 其他表单项变量名
|
||||
- `value` (string) 其他表单项变量值
|
||||
|
|
@ -222,6 +222,59 @@ class TencentSpanBuilder:
|
|||
links=links,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_message_llm_span(
|
||||
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
|
||||
) -> SpanData:
|
||||
"""Build LLM span for message traces with detailed LLM attributes."""
|
||||
status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
# Extract model information from `metadata` or `message_data`
|
||||
trace_metadata = trace_info.metadata or {}
|
||||
message_data = trace_info.message_data or {}
|
||||
|
||||
model_provider = trace_metadata.get("ls_provider") or (
|
||||
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
|
||||
)
|
||||
model_name = trace_metadata.get("ls_model_name") or (
|
||||
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
|
||||
)
|
||||
|
||||
inputs_str = str(trace_info.inputs or "")
|
||||
outputs_str = str(trace_info.outputs or "")
|
||||
|
||||
attributes = {
|
||||
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: str(model_name),
|
||||
GEN_AI_PROVIDER: str(model_provider),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
|
||||
GEN_AI_PROMPT: inputs_str,
|
||||
GEN_AI_COMPLETION: outputs_str,
|
||||
INPUT_VALUE: inputs_str,
|
||||
OUTPUT_VALUE: outputs_str,
|
||||
}
|
||||
|
||||
if trace_info.is_streaming_request:
|
||||
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=parent_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
|
||||
name="GENERATION",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes=attributes,
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
|
||||
"""Build tool span."""
|
||||
|
|
|
|||
|
|
@ -107,9 +107,13 @@ class TencentDataTrace(BaseTraceInstance):
|
|||
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
|
||||
|
||||
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
|
||||
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
# Add LLM child span with detailed attributes
|
||||
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
|
||||
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
|
||||
self.trace_client.add_span(llm_span)
|
||||
|
||||
self._record_message_llm_metrics(trace_info)
|
||||
|
||||
# Record trace duration for entry span
|
||||
|
|
|
|||
|
|
@ -1,20 +1,110 @@
|
|||
import re
|
||||
from operator import itemgetter
|
||||
from typing import cast
|
||||
|
||||
|
||||
class JiebaKeywordTableHandler:
|
||||
def __init__(self):
|
||||
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
|
||||
|
||||
tfidf = self._load_tfidf_extractor()
|
||||
tfidf.stop_words = STOPWORDS # type: ignore[attr-defined]
|
||||
self._tfidf = tfidf
|
||||
|
||||
def _load_tfidf_extractor(self):
|
||||
"""
|
||||
Load jieba TFIDF extractor with fallback strategy.
|
||||
|
||||
Loading Flow:
|
||||
┌─────────────────────────────────────────────────────────────────────┐
|
||||
│ jieba.analyse.default_tfidf │
|
||||
│ exists? │
|
||||
└─────────────────────────────────────────────────────────────────────┘
|
||||
│ │
|
||||
YES NO
|
||||
│ │
|
||||
▼ ▼
|
||||
┌──────────────────┐ ┌──────────────────────────────────┐
|
||||
│ Return default │ │ jieba.analyse.TFIDF exists? │
|
||||
│ TFIDF │ └──────────────────────────────────┘
|
||||
└──────────────────┘ │ │
|
||||
YES NO
|
||||
│ │
|
||||
│ ▼
|
||||
│ ┌────────────────────────────┐
|
||||
│ │ Try import from │
|
||||
│ │ jieba.analyse.tfidf.TFIDF │
|
||||
│ └────────────────────────────┘
|
||||
│ │ │
|
||||
│ SUCCESS FAILED
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌────────────────────────┐ ┌─────────────────┐
|
||||
│ Instantiate TFIDF() │ │ Build fallback │
|
||||
│ & cache to default │ │ _SimpleTFIDF │
|
||||
└────────────────────────┘ └─────────────────┘
|
||||
"""
|
||||
import jieba.analyse # type: ignore
|
||||
|
||||
tfidf = getattr(jieba.analyse, "default_tfidf", None)
|
||||
if tfidf is not None:
|
||||
return tfidf
|
||||
|
||||
tfidf_class = getattr(jieba.analyse, "TFIDF", None)
|
||||
if tfidf_class is None:
|
||||
try:
|
||||
from jieba.analyse.tfidf import TFIDF # type: ignore
|
||||
|
||||
tfidf_class = TFIDF
|
||||
except Exception:
|
||||
tfidf_class = None
|
||||
|
||||
if tfidf_class is not None:
|
||||
tfidf = tfidf_class()
|
||||
jieba.analyse.default_tfidf = tfidf # type: ignore[attr-defined]
|
||||
return tfidf
|
||||
|
||||
return self._build_fallback_tfidf()
|
||||
|
||||
@staticmethod
|
||||
def _build_fallback_tfidf():
|
||||
"""Fallback lightweight TFIDF for environments missing jieba's TFIDF."""
|
||||
import jieba # type: ignore
|
||||
|
||||
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
|
||||
|
||||
|
||||
class _SimpleTFIDF:
|
||||
def __init__(self):
|
||||
self.stop_words = STOPWORDS
|
||||
self._lcut = getattr(jieba, "lcut", None)
|
||||
|
||||
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
|
||||
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
|
||||
top_k = kwargs.pop("topK", top_k)
|
||||
cut = getattr(jieba, "cut", None)
|
||||
if self._lcut:
|
||||
tokens = self._lcut(sentence)
|
||||
elif callable(cut):
|
||||
tokens = list(cut(sentence))
|
||||
else:
|
||||
tokens = re.findall(r"\w+", sentence)
|
||||
|
||||
words = [w for w in tokens if w and w not in self.stop_words]
|
||||
freq: dict[str, int] = {}
|
||||
for w in words:
|
||||
freq[w] = freq.get(w, 0) + 1
|
||||
|
||||
sorted_words = sorted(freq.items(), key=itemgetter(1), reverse=True)
|
||||
if top_k is not None:
|
||||
sorted_words = sorted_words[:top_k]
|
||||
|
||||
return [item[0] for item in sorted_words]
|
||||
|
||||
return _SimpleTFIDF()
|
||||
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: int | None = 10) -> set[str]:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
import jieba.analyse # type: ignore
|
||||
|
||||
keywords = jieba.analyse.extract_tags(
|
||||
keywords = self._tfidf.extract_tags(
|
||||
sentence=text,
|
||||
topK=max_keywords_per_chunk,
|
||||
)
|
||||
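A minimal usage sketch of the handler above; the import path is assumed from the neighbouring STOPWORDS module and may differ from the actual file name. The fallback extractor is selected transparently when jieba's bundled TFIDF is unavailable.

# Hypothetical import path; adjust to where JiebaKeywordTableHandler actually lives.
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

handler = JiebaKeywordTableHandler()
keywords = handler.extract_keywords(
    "Dify workflows combine RAG retrieval with LLM orchestration",
    max_keywords_per_chunk=5,
)
print(sorted(keywords))  # up to 5 keywords with stopwords filtered out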
|
|
|
|||
|
|
@ -58,11 +58,39 @@ class OceanBaseVector(BaseVector):
|
|||
password=self._config.password,
|
||||
db_name=self._config.database,
|
||||
)
|
||||
self._fields: list[str] = [] # List of fields in the collection
|
||||
if self._client.check_table_exists(collection_name):
|
||||
self._load_collection_fields()
|
||||
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.OCEANBASE
|
||||
|
||||
def _load_collection_fields(self):
|
||||
"""
|
||||
Load collection fields from the database table.
|
||||
This method populates the _fields list with column names from the table.
|
||||
"""
|
||||
try:
|
||||
if self._collection_name in self._client.metadata_obj.tables:
|
||||
table = self._client.metadata_obj.tables[self._collection_name]
|
||||
# Store all column names except 'id' (primary key)
|
||||
self._fields = [column.name for column in table.columns if column.name != "id"]
|
||||
logger.debug("Loaded fields for collection '%s': %s", self._collection_name, self._fields)
|
||||
else:
|
||||
logger.warning("Collection '%s' not found in metadata", self._collection_name)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load collection fields for '%s': %s", self._collection_name, str(e))
|
||||
|
||||
def field_exists(self, field: str) -> bool:
|
||||
"""
|
||||
Check if a field exists in the collection.
|
||||
|
||||
:param field: Field name to check
|
||||
:return: True if field exists, False otherwise
|
||||
"""
|
||||
return field in self._fields
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self._vec_dim = len(embeddings[0])
|
||||
self._create_collection()
|
||||
|
|
@ -151,6 +179,7 @@ class OceanBaseVector(BaseVector):
|
|||
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
|
||||
|
||||
self._client.refresh_metadata([self._collection_name])
|
||||
self._load_collection_fields()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _check_hybrid_search_support(self) -> bool:
|
||||
|
|
@ -177,42 +206,134 @@ class OceanBaseVector(BaseVector):
|
|||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = self._get_uuids(documents)
|
||||
for id, doc, emb in zip(ids, documents, embeddings):
|
||||
self._client.insert(
|
||||
table_name=self._collection_name,
|
||||
data={
|
||||
"id": id,
|
||||
"vector": emb,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
},
|
||||
)
|
||||
try:
|
||||
self._client.insert(
|
||||
table_name=self._collection_name,
|
||||
data={
|
||||
"id": id,
|
||||
"vector": emb,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to insert document with id '%s' in collection '%s'",
|
||||
id,
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to insert document with id '{id}'") from e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
cur = self._client.get(table_name=self._collection_name, ids=id)
|
||||
return bool(cur.rowcount != 0)
|
||||
try:
|
||||
cur = self._client.get(table_name=self._collection_name, ids=id)
|
||||
return bool(cur.rowcount != 0)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to check if text exists with id '%s' in collection '%s'",
|
||||
id,
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to check text existence for id '{id}'") from e
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
try:
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
logger.debug("Deleted %d documents from collection '%s'", len(ids), self._collection_name)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to delete %d documents from collection '%s'",
|
||||
len(ids),
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
from sqlalchemy import text
|
||||
try:
|
||||
import re
|
||||
|
||||
cur = self._client.get(
|
||||
table_name=self._collection_name,
|
||||
ids=None,
|
||||
where_clause=[text(f"metadata->>'$.{key}' = '{value}'")],
|
||||
output_column_name=["id"],
|
||||
)
|
||||
return [row[0] for row in cur]
|
||||
from sqlalchemy import text
|
||||
|
||||
# Validate key to prevent injection in JSON path
|
||||
if not re.match(r"^[a-zA-Z0-9_.]+$", key):
|
||||
raise ValueError(f"Invalid characters in metadata key: {key}")
|
||||
|
||||
# Use parameterized query to prevent SQL injection
|
||||
sql = text(f"SELECT id FROM `{self._collection_name}` WHERE metadata->>'$.{key}' = :value")
|
||||
|
||||
with self._client.engine.connect() as conn:
|
||||
result = conn.execute(sql, {"value": value})
|
||||
ids = [row[0] for row in result]
|
||||
|
||||
logger.debug(
|
||||
"Found %d documents with metadata field '%s'='%s' in collection '%s'",
|
||||
len(ids),
|
||||
key,
|
||||
value,
|
||||
self._collection_name,
|
||||
)
|
||||
return ids
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to get IDs by metadata field '%s'='%s' in collection '%s'",
|
||||
key,
|
||||
value,
|
||||
self._collection_name,
|
||||
)
|
||||
raise Exception(f"Failed to query documents by metadata field '{key}'") from e
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
self.delete_by_ids(ids)
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
else:
|
||||
logger.debug("No documents found to delete with metadata field '%s'='%s'", key, value)
|
||||
|
||||
def _process_search_results(
|
||||
self, results: list[tuple], score_threshold: float = 0.0, score_key: str = "score"
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Common method to process search results
|
||||
|
||||
:param results: Search results as list of tuples (text, metadata, score)
|
||||
:param score_threshold: Score threshold for filtering
|
||||
:param score_key: Key name for score in metadata
|
||||
:return: List of documents
|
||||
"""
|
||||
docs = []
|
||||
for row in results:
|
||||
text, metadata_str, score = row[0], row[1], row[2]
|
||||
|
||||
# Parse metadata JSON
|
||||
try:
|
||||
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON metadata: %s", metadata_str)
|
||||
metadata = {}
|
||||
|
||||
# Add score to metadata
|
||||
metadata[score_key] = score
|
||||
|
||||
# Filter by score threshold
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return docs
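For reference, a toy invocation of the helper above; `vector` stands in for a constructed OceanBaseVector instance (construction omitted).

# Hedged example of threshold filtering and score injection.
rows = [
    ("hello world", '{"document_id": "doc-1"}', 0.92),  # kept: 0.92 >= threshold
    ("low match", '{"document_id": "doc-2"}', 0.10),    # dropped by the threshold
]
docs = vector._process_search_results(rows, score_threshold=0.5)
assert len(docs) == 1 and docs[0].metadata["score"] == 0.92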
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
if not self._hybrid_search_enabled:
|
||||
logger.warning(
|
||||
"Full-text search is disabled: set OCEANBASE_ENABLE_HYBRID_SEARCH=true (requires OceanBase >= 4.3.5.1)."
|
||||
)
|
||||
return []
|
||||
if not self.field_exists("text"):
|
||||
logger.warning(
|
||||
"Full-text search unavailable: collection '%s' missing 'text' field; "
|
||||
"recreate the collection after enabling OCEANBASE_ENABLE_HYBRID_SEARCH to add fulltext index.",
|
||||
self._collection_name,
|
||||
)
|
||||
return []
|
||||
|
||||
try:
|
||||
|
|
@ -220,13 +341,24 @@ class OceanBaseVector(BaseVector):
|
|||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
|
||||
# Build parameterized query to prevent SQL injection
|
||||
from sqlalchemy import text
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
params = {"query": query}
|
||||
where_clause = ""
|
||||
|
||||
if document_ids_filter:
|
||||
# Create parameterized placeholders for document IDs
|
||||
placeholders = ", ".join(f":doc_id_{i}" for i in range(len(document_ids_filter)))
|
||||
where_clause = f" AND metadata->>'$.document_id' IN ({placeholders})"
|
||||
# Add document IDs to parameters
|
||||
for i, doc_id in enumerate(document_ids_filter):
|
||||
params[f"doc_id_{i}"] = doc_id
|
||||
|
||||
full_sql = f"""SELECT text, metadata, MATCH (text) AGAINST (:query) AS score
|
||||
FROM {self._collection_name}
|
||||
WHERE MATCH (text) AGAINST (:query) > 0
|
||||
{where_clause}
|
||||
|
|
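The placeholder construction above generalizes to any bound IN list; a minimal sketch with hypothetical values:

# Hedged sketch: one named bind parameter per document id, no string interpolation of values.
document_ids_filter = ["doc-a", "doc-b"]
params = {"query": "tennis"}
placeholders = ", ".join(f":doc_id_{i}" for i in range(len(document_ids_filter)))
where_clause = f" AND metadata->>'$.document_id' IN ({placeholders})"
for i, doc_id in enumerate(document_ids_filter):
    params[f"doc_id_{i}"] = doc_id
assert where_clause == " AND metadata->>'$.document_id' IN (:doc_id_0, :doc_id_1)"
assert params == {"query": "tennis", "doc_id_0": "doc-a", "doc_id_1": "doc-b"}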
@ -235,41 +367,45 @@ class OceanBaseVector(BaseVector):
|
|||
|
||||
with self._client.engine.connect() as conn:
|
||||
with conn.begin():
|
||||
from sqlalchemy import text
|
||||
|
||||
result = conn.execute(text(full_sql), {"query": query})
|
||||
result = conn.execute(text(full_sql), params)
|
||||
rows = result.fetchall()
|
||||
|
||||
docs = []
|
||||
for row in rows:
|
||||
metadata_str, _text, score = row
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON metadata: %s", metadata_str)
|
||||
metadata = {}
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=_text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
return self._process_search_results(rows, score_threshold=score_threshold)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fulltext search: %s.", str(e))
|
||||
return []
|
||||
logger.exception(
|
||||
"Failed to perform full-text search on collection '%s' with query '%s'",
|
||||
self._collection_name,
|
||||
query,
|
||||
)
|
||||
raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from sqlalchemy import text
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
_where_clause = None
|
||||
if document_ids_filter:
|
||||
# Validate document IDs to prevent SQL injection
|
||||
# Document IDs should be alphanumeric with hyphens and underscores
|
||||
import re
|
||||
|
||||
for doc_id in document_ids_filter:
|
||||
if not isinstance(doc_id, str) or not re.match(r"^[a-zA-Z0-9_-]+$", doc_id):
|
||||
raise ValueError(f"Invalid document ID format: {doc_id}")
|
||||
|
||||
# Safe to use in query after validation
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
|
||||
from sqlalchemy import text
|
||||
|
||||
_where_clause = [text(where_clause)]
|
||||
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
|
||||
if ef_search != self._hnsw_ef_search:
|
||||
self._client.set_ob_hnsw_ef_search(ef_search)
|
||||
self._hnsw_ef_search = ef_search
|
||||
topk = kwargs.get("top_k", 10)
|
||||
try:
|
||||
score_threshold = float(val) if (val := kwargs.get("score_threshold")) is not None else 0.0
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid score_threshold parameter: {e}") from e
|
||||
try:
|
||||
cur = self._client.ann_search(
|
||||
table_name=self._collection_name,
|
||||
|
|
@ -282,21 +418,27 @@ class OceanBaseVector(BaseVector):
|
|||
where_clause=_where_clause,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception("Failed to search by vector. ", e)
|
||||
docs = []
|
||||
for _text, metadata, distance in cur:
|
||||
metadata = json.loads(metadata)
|
||||
metadata["score"] = 1 - distance / math.sqrt(2)
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=_text,
|
||||
metadata=metadata,
|
||||
)
|
||||
logger.exception(
|
||||
"Failed to perform vector search on collection '%s'",
|
||||
self._collection_name,
|
||||
)
|
||||
return docs
|
||||
raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e
|
||||
|
||||
# Convert distance to score and prepare results for processing
|
||||
results = []
|
||||
for _text, metadata_str, distance in cur:
|
||||
score = 1 - distance / math.sqrt(2)
|
||||
results.append((_text, metadata_str, score))
|
||||
|
||||
return self._process_search_results(results, score_threshold=score_threshold)
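The distance-to-score conversion used above, isolated for clarity; the sqrt(2) normalization assumes the distances returned by ann_search fall roughly in [0, sqrt(2)] for the configured metric.

import math

def distance_to_score(distance: float) -> float:
    # Smaller ANN distances map to scores nearer 1.0; a distance of sqrt(2) maps to 0.
    return 1 - distance / math.sqrt(2)

assert distance_to_score(0.0) == 1.0
assert abs(distance_to_score(math.sqrt(2))) < 1e-9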
|
||||
|
||||
def delete(self):
|
||||
self._client.drop_table_if_exist(self._collection_name)
|
||||
try:
|
||||
self._client.drop_table_if_exist(self._collection_name)
|
||||
logger.debug("Dropped collection '%s'", self._collection_name)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to delete collection '%s'", self._collection_name)
|
||||
raise Exception(f"Failed to delete collection '{self._collection_name}'") from e
|
||||
|
||||
|
||||
class OceanBaseVectorFactory(AbstractVectorFactory):
|
||||
|
|
|
|||
|
|
@ -302,8 +302,7 @@ class OracleVector(BaseVector):
|
|||
nltk.data.find("tokenizers/punkt")
|
||||
nltk.data.find("corpora/stopwords")
|
||||
except LookupError:
|
||||
nltk.download("punkt")
|
||||
nltk.download("stopwords")
|
||||
raise LookupError("Unable to find the required NLTK data package: punkt and stopwords")
|
||||
e_str = re.sub(r"[^\w ]", "", query)
|
||||
all_tokens = nltk.word_tokenize(e_str)
|
||||
stop_words = stopwords.words("english")
|
||||
|
|
|
|||
|
|
@ -167,13 +167,18 @@ class WeaviateVector(BaseVector):
|
|||
|
||||
try:
|
||||
if not self._client.collections.exists(self._collection_name):
|
||||
tokenization = (
|
||||
wc.Tokenization(dify_config.WEAVIATE_TOKENIZATION)
|
||||
if dify_config.WEAVIATE_TOKENIZATION
|
||||
else wc.Tokenization.WORD
|
||||
)
|
||||
self._client.collections.create(
|
||||
name=self._collection_name,
|
||||
properties=[
|
||||
wc.Property(
|
||||
name=Field.TEXT_KEY.value,
|
||||
data_type=wc.DataType.TEXT,
|
||||
tokenization=wc.Tokenization.WORD,
|
||||
tokenization=tokenization,
|
||||
),
|
||||
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
|
||||
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ class ToolProviderApiEntity(BaseModel):
|
|||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
# Workflow
|
||||
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -87,6 +89,8 @@ class ToolProviderApiEntity(BaseModel):
|
|||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
elif self.type == ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from pydantic import BaseModel
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
|
@ -25,3 +27,5 @@ class ApiToolBundle(BaseModel):
|
|||
icon: str | None = None
|
||||
# openapi operation
|
||||
openapi: dict
|
||||
# output schema
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
|
|
|||
|
|
@ -123,6 +123,16 @@ class ApiProviderAuthType(StrEnum):
|
|||
raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")
|
||||
|
||||
|
||||
class ToolAuthType(StrEnum):
|
||||
"""
|
||||
Enum class for tool authentication type.
|
||||
Determines whether OAuth credentials are workspace-level or end-user-level.
|
||||
"""
|
||||
|
||||
WORKSPACE = "workspace"
|
||||
END_USER = "end_user"
|
||||
|
||||
|
||||
class ToolInvokeMessage(BaseModel):
|
||||
class TextMessage(BaseModel):
|
||||
text: str
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any
|
|||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
|
|
@ -24,6 +25,31 @@ class WorkflowToolConfigurationUtils:
|
|||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_output(cls, graph: Mapping[str, Any]) -> Sequence[OutputVariableEntity]:
|
||||
"""
|
||||
Collect the output variables declared by the workflow graph's end nodes; later end nodes override duplicate variable definitions.
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
outputs_by_variable: dict[str, OutputVariableEntity] = {}
|
||||
variable_order: list[str] = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") != "end":
|
||||
continue
|
||||
|
||||
for output in node.get("data", {}).get("outputs", []):
|
||||
entity = OutputVariableEntity.model_validate(output)
|
||||
variable = entity.variable
|
||||
|
||||
if variable not in variable_order:
|
||||
variable_order.append(variable)
|
||||
|
||||
# Later end nodes override duplicated variable definitions.
|
||||
outputs_by_variable[variable] = entity
|
||||
|
||||
return [outputs_by_variable[variable] for variable in variable_order]
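A small hedged illustration of the dedup-and-override behaviour; OutputVariableType is assumed importable from the entities module referenced above.

graph = {
    "nodes": [
        {"data": {"type": "end", "outputs": [
            {"variable": "answer", "value_type": "string", "value_selector": ["llm", "text"]}]}},
        {"data": {"type": "end", "outputs": [
            {"variable": "answer", "value_type": "object", "value_selector": ["code", "result"]}]}},
    ]
}
outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph)
assert [o.variable for o in outputs] == ["answer"]         # first-seen order preserved
assert outputs[0].value_type == OutputVariableType.OBJECT  # later end node wins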
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
|
|
|
|||
|
|
@ -141,6 +141,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
form=parameter.form,
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
default=variable.default,
|
||||
options=options,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
|
|
@ -161,6 +162,20 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
else:
|
||||
raise ValueError("variable not found")
|
||||
|
||||
# get output schema from workflow
|
||||
outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph)
|
||||
|
||||
reserved_keys = {"json", "text", "files"}
|
||||
|
||||
properties = {}
|
||||
for output in outputs:
|
||||
if output.variable not in reserved_keys:
|
||||
properties[output.variable] = {
|
||||
"type": output.value_type,
|
||||
"description": "",
|
||||
}
|
||||
output_schema = {"type": "object", "properties": properties}
|
||||
|
||||
return WorkflowTool(
|
||||
workflow_as_tool_id=db_provider.id,
|
||||
entity=ToolEntity(
|
||||
|
|
@ -176,6 +191,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
output_schema=output_schema,
|
||||
),
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
|
|
|
|||
|
|
@ -114,6 +114,11 @@ class WorkflowTool(Tool):
|
|||
for file in files:
|
||||
yield self.create_file_message(file) # type: ignore
|
||||
|
||||
# traverse `outputs` field and create variable messages
|
||||
for key, value in outputs.items():
|
||||
if key not in {"text", "json", "files"}:
|
||||
yield self.create_variable_message(variable_name=key, variable_value=value)
|
||||
|
||||
self._latest_usage = self._derive_usage_from_result(data)
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
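The reserved-key split above, shown in isolation with hypothetical outputs:

# Hedged illustration: only non-reserved output keys become variable messages.
outputs = {"text": "...", "json": [{}], "files": [], "total": 3, "summary": "ok"}
custom_keys = [k for k in outputs if k not in {"text", "json", "files"}]
assert custom_keys == ["total", "summary"]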
|
||||
|
|
|
|||
|
|
@ -71,6 +71,11 @@ class TriggerProviderIdentity(BaseModel):
|
|||
icon_dark: str | None = Field(default=None, description="The dark icon of the trigger provider")
|
||||
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def validate_tags(cls, v: list[str] | None) -> list[str]:
|
||||
return v or []
|
||||
|
||||
|
||||
class EventIdentity(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,17 +1,11 @@
|
|||
from ..runtime.graph_runtime_state import GraphRuntimeState
|
||||
from ..runtime.variable_pool import VariablePool
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_pause import WorkflowPauseEntity
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"VariablePool",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"WorkflowPauseEntity",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,49 +1,26 @@
|
|||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, ClassVar, TypeAlias
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Discriminator, Tag
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class _PauseReasonType(StrEnum):
|
||||
class PauseReasonType(StrEnum):
|
||||
HUMAN_INPUT_REQUIRED = auto()
|
||||
SCHEDULED_PAUSE = auto()
|
||||
|
||||
|
||||
class _PauseReasonBase(BaseModel):
|
||||
TYPE: ClassVar[_PauseReasonType]
|
||||
class HumanInputRequired(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
form_id: str
|
||||
# The identifier of the human input node causing the pause.
|
||||
node_id: str
|
||||
|
||||
|
||||
class HumanInputRequired(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
|
||||
class SchedulingPause(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.SCHEDULED_PAUSE
|
||||
class SchedulingPause(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
|
||||
if isinstance(v, _PauseReasonBase):
|
||||
return v.TYPE
|
||||
elif isinstance(v, dict):
|
||||
reason_type_str = v.get("TYPE")
|
||||
if reason_type_str is None:
|
||||
return None
|
||||
try:
|
||||
reason_type = _PauseReasonType(reason_type_str)
|
||||
except ValueError:
|
||||
return None
|
||||
return reason_type
|
||||
else:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
|
||||
|
||||
PauseReason: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
|
||||
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
|
||||
),
|
||||
Discriminator(_get_pause_reason_discriminator),
|
||||
]
|
||||
PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]
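A hedged round-trip sketch for the simplified discriminated union; _PausedState is illustrative only, and the string values assume StrEnum auto() lower-casing the member names.

from pydantic import BaseModel

class _PausedState(BaseModel):
    reasons: list[PauseReason]

payload = {"reasons": [
    {"TYPE": "human_input_required", "form_id": "form-1", "node_id": "node-1"},
    {"TYPE": "scheduled_pause", "message": "resume at 09:00"},
]}
state = _PausedState.model_validate(payload)
assert isinstance(state.reasons[0], HumanInputRequired)
assert isinstance(state.reasons[1], SchedulingPause)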
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
|
|||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
paused: bool = Field(default=False)
|
||||
pause_reason: PauseReason | None = Field(default=None)
|
||||
pause_reasons: list[PauseReason] = Field(default_factory=list)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
|
|
@ -107,7 +107,7 @@ class GraphExecution:
|
|||
completed: bool = False
|
||||
aborted: bool = False
|
||||
paused: bool = False
|
||||
pause_reason: PauseReason | None = None
|
||||
pause_reasons: list[PauseReason] = field(default_factory=list)
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
|
|
@ -137,10 +137,8 @@ class GraphExecution:
|
|||
raise RuntimeError("Cannot pause execution that has completed")
|
||||
if self.aborted:
|
||||
raise RuntimeError("Cannot pause execution that has been aborted")
|
||||
if self.paused:
|
||||
return
|
||||
self.paused = True
|
||||
self.pause_reason = reason
|
||||
self.pause_reasons.append(reason)
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Mark the graph execution as failed."""
|
||||
|
|
@ -195,7 +193,7 @@ class GraphExecution:
|
|||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
paused=self.paused,
|
||||
pause_reason=self.pause_reason,
|
||||
pause_reasons=self.pause_reasons,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
|
|
@ -221,7 +219,7 @@ class GraphExecution:
|
|||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.paused = state.paused
|
||||
self.pause_reason = state.pause_reason
|
||||
self.pause_reasons = state.pause_reasons
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
|
|
|
|||
|
|
@ -110,7 +110,13 @@ class EventManager:
|
|||
"""
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
# NOTE: `_notify_layers` is intentionally called outside the critical section
|
||||
# to minimize lock contention and avoid blocking other readers or writers.
|
||||
#
|
||||
# The public `notify_layers` method also does not use a write lock,
|
||||
# so protecting `_notify_layers` with a lock here is unnecessary.
|
||||
self._notify_layers(event)
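The note above boils down to a common pattern: mutate shared state inside the lock, notify listeners outside it. A minimal hedged sketch (threading.Lock stands in for the read/write lock used here):

import threading

class _MiniEventManager:
    def __init__(self):
        self._lock = threading.Lock()
        self._events: list[object] = []
        self._listeners: list = []

    def collect(self, event: object) -> None:
        with self._lock:                      # critical section covers only the append
            self._events.append(event)
        for listener in self._listeners:      # listeners run without the lock held
            listener(event)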
|
||||
|
||||
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ class GraphEngine:
|
|||
self._graph_execution.start()
|
||||
else:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reason = None
|
||||
self._graph_execution.pause_reasons = []
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
self._event_manager.notify_layers(start_event)
|
||||
|
|
@ -246,11 +246,11 @@ class GraphEngine:
|
|||
|
||||
# Handle completion
|
||||
if self._graph_execution.is_paused:
|
||||
pause_reason = self._graph_execution.pause_reason
|
||||
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
|
||||
pause_reasons = self._graph_execution.pause_reasons
|
||||
assert pause_reasons, "pause_reasons should not be empty when execution is paused."
|
||||
# Ensure we have a valid PauseReason for the event
|
||||
paused_event = GraphRunPausedEvent(
|
||||
reason=pause_reason,
|
||||
reasons=pause_reasons,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(paused_event)
|
||||
|
|
|
|||
|
|
@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
|
|||
class GraphRunPausedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is paused by user command."""
|
||||
|
||||
# reason: str | None = Field(default=None, description="reason for pause")
|
||||
reason: PauseReason = Field(..., description="reason for pause")
|
||||
reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs available to the client while the run is paused.",
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ from core.tools.tool_manager import ToolManager
|
|||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayFileSegment, StringSegment
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
|
|
@ -40,7 +39,6 @@ from core.workflow.node_events import (
|
|||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
|
@ -66,34 +64,12 @@ if TYPE_CHECKING:
|
|||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
|
||||
class AgentNode(Node):
|
||||
class AgentNode(Node[AgentNodeData]):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
node_type = NodeType.AGENT
|
||||
_node_data: AgentNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AgentNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
|
|
@ -105,8 +81,8 @@ class AgentNode(Node):
|
|||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self._node_data.agent_strategy_name,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
|
|
@ -124,13 +100,13 @@ class AgentNode(Node):
|
|||
parameters = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
|
|
@ -163,7 +139,7 @@ class AgentNode(Node):
|
|||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": self._node_data.agent_strategy_name,
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
|
|
@ -410,7 +386,7 @@ class AgentNode(Node):
|
|||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
|
|
|
|||
|
|
@ -2,48 +2,24 @@ from collections.abc import Mapping, Sequence
|
|||
from typing import Any
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerNode(Node):
|
||||
class AnswerNode(Node[AnswerNodeData]):
|
||||
node_type = NodeType.ANSWER
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
_node_data: AnswerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AnswerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
|
||||
segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer)
|
||||
files = self._extract_files_from_segments(segments.value)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
|
@ -93,4 +69,4 @@ class AnswerNode(Node):
|
|||
Returns:
|
||||
Template instance for this Answer node
|
||||
"""
|
||||
return Template.from_answer_template(self._node_data.answer)
|
||||
return Template.from_answer_template(self.node_data.answer)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from collections.abc import Sequence
|
|||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
|
||||
|
|
@ -35,6 +35,45 @@ class VariableSelector(BaseModel):
|
|||
value_selector: Sequence[str]
|
||||
|
||||
|
||||
class OutputVariableType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
SECRET = "secret"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
ARRAY = "array"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ANY = "any"
|
||||
ARRAY_ANY = "array[any]"
|
||||
|
||||
|
||||
class OutputVariableEntity(BaseModel):
|
||||
"""
|
||||
Output Variable Entity.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_type: OutputVariableType
|
||||
value_selector: Sequence[str]
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_value_type(cls, v: Any) -> Any:
|
||||
"""
|
||||
Normalize value_type to handle case-insensitive array types.
|
||||
Converts 'Array[...]' to 'array[...]' for backward compatibility.
|
||||
"""
|
||||
if isinstance(v, str) and v.startswith("Array["):
|
||||
return v.lower()
|
||||
return v
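A quick hedged check of the normalization above:

entity = OutputVariableEntity.model_validate(
    {"variable": "result", "value_type": "Array[string]", "value_selector": ["end", "result"]}
)
assert entity.value_type == OutputVariableType.ARRAY_STRING  # legacy 'Array[...]' accepted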
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
@ -49,12 +49,121 @@ from models.enums import UserFrom
|
|||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Node:
|
||||
class Node(Generic[NodeDataT]):
|
||||
node_type: ClassVar["NodeType"]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
"""
|
||||
Automatically extract and validate the node data type from the generic parameter.
|
||||
|
||||
When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method:
|
||||
1. Inspects `__orig_bases__` to find the `Node[T]` parameterization
|
||||
2. Extracts `T` (e.g., `MyNodeData`) from the generic argument
|
||||
3. Validates that `T` is a proper `BaseNodeData` subclass
|
||||
4. Stores it in `_node_data_type` for automatic hydration in `__init__`
|
||||
|
||||
This eliminates the need for subclasses to manually implement boilerplate
|
||||
accessor methods like `_get_title()`, `_get_error_strategy()`, etc.
|
||||
|
||||
How it works:
|
||||
::
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
│ │
|
||||
│ └─────────────────────────────────┐
|
||||
│ │
|
||||
▼ ▼
|
||||
┌─────────────────────────────┐ ┌─────────────────────────────────┐
|
||||
│ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │
|
||||
│ Node[CodeNodeData], │ │ title: str │
|
||||
│ ) │ │ desc: str | None │
|
||||
└──────────────┬──────────────┘ │ ... │
|
||||
│ └─────────────────────────────────┘
|
||||
▼ ▲
|
||||
┌─────────────────────────────┐ │
|
||||
│ get_origin(base) -> Node │ │
|
||||
│ get_args(base) -> ( │ │
|
||||
│ CodeNodeData, │ ──────────────────────┘
|
||||
│ ) │
|
||||
└──────────────┬──────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────┐
|
||||
│ Validate: │
|
||||
│ - Is it a type? │
|
||||
│ - Is it a BaseNodeData │
|
||||
│ subclass? │
|
||||
└──────────────┬──────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────┐
|
||||
│ cls._node_data_type = │
|
||||
│ CodeNodeData │
|
||||
└─────────────────────────────┘
|
||||
|
||||
Later, in __init__:
|
||||
::
|
||||
|
||||
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
|
||||
Example:
|
||||
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||
node_type = NodeType.CODE
|
||||
# No need to implement _get_title, _get_error_strategy, etc.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
if cls is Node:
|
||||
return
|
||||
|
||||
node_data_type = cls._extract_node_data_type_from_generic()
|
||||
|
||||
if node_data_type is None:
|
||||
raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype")
|
||||
|
||||
cls._node_data_type = node_data_type
|
||||
|
||||
@classmethod
|
||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||
"""
|
||||
Extract the node data type from the generic parameter `Node[T]`.
|
||||
|
||||
Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`.
|
||||
|
||||
Returns:
|
||||
The extracted BaseNodeData subtype, or None if not found.
|
||||
|
||||
Raises:
|
||||
TypeError: If the generic argument is invalid (not exactly one argument,
|
||||
or not a BaseNodeData subtype).
|
||||
"""
|
||||
# __orig_bases__ contains the original generic bases before type erasure.
|
||||
# For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`.
|
||||
for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined]
|
||||
origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]`
|
||||
if origin is Node:
|
||||
args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]`
|
||||
if len(args) != 1:
|
||||
raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument")
|
||||
|
||||
candidate = args[0]
|
||||
if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
|
||||
raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")
|
||||
|
||||
return candidate
|
||||
|
||||
return None
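The extraction mechanism above, reduced to a self-contained hedged sketch independent of the workflow classes:

from typing import Generic, TypeVar, get_args, get_origin

class Config: ...
class MyConfig(Config): ...

T = TypeVar("T", bound=Config)

class Base(Generic[T]):
    _cfg_type: type[Config] = Config

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # __orig_bases__ keeps the parameterized base, e.g. (Base[MyConfig],)
        for base in getattr(cls, "__orig_bases__", ()):
            if get_origin(base) is Base:
                (candidate,) = get_args(base)
                cls._cfg_type = candidate

class Child(Base[MyConfig]):
    pass

assert Child._cfg_type is MyConfig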
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -63,6 +172,7 @@ class Node:
|
|||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
self._graph_init_params = graph_init_params
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
|
|
@ -83,8 +193,24 @@ class Node:
|
|||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
@abstractmethod
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
raise ValueError("Node config data must be a mapping.")
|
||||
|
||||
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self) -> None:
|
||||
"""Optional hook for subclasses requiring extra initialization."""
|
||||
return
|
||||
|
||||
@property
|
||||
def graph_init_params(self) -> "GraphInitParams":
|
||||
return self._graph_init_params
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
|
|
@ -114,23 +240,23 @@ class Node:
|
|||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
|
||||
provider_name = getattr(self.get_base_node_data(), "provider_name", "")
|
||||
plugin_id = getattr(self.node_data, "plugin_id", "")
|
||||
provider_name = getattr(self.node_data, "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
|
|
@ -139,7 +265,7 @@ class Node:
|
|||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
|
||||
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
|
|
@ -273,38 +399,25 @@ class Node:
|
|||
def retry(self) -> bool:
|
||||
return False
|
||||
|
||||
# Abstract methods that subclasses must implement to provide access
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
return self._node_data.error_strategy
|
||||
|
||||
@abstractmethod
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
...
|
||||
return self._node_data.retry_config
|
||||
|
||||
@abstractmethod
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
...
|
||||
return self._node_data.title
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
...
|
||||
return self._node_data.desc
|
||||
|
||||
@abstractmethod
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
...
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
|
|
@ -332,6 +445,11 @@ class Node:
|
|||
"""Get the default values dictionary for this node."""
|
||||
return self._get_default_value_dict()
|
||||
|
||||
@property
|
||||
def node_data(self) -> NodeDataT:
|
||||
"""Typed access to this node's configuration data."""
|
||||
return self._node_data
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
|
|
@ -426,7 +544,7 @@ class Node:
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
|
|
@ -439,7 +557,7 @@ class Node:
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.pre_loop_output,
|
||||
)
|
||||
|
|
@ -450,7 +568,7 @@ class Node:
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -464,7 +582,7 @@ class Node:
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -479,7 +597,7 @@ class Node:
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
|
|
@ -492,7 +610,7 @@ class Node:
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.pre_iteration_output,
|
||||
)
|
||||
|
|
@ -503,7 +621,7 @@ class Node:
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
@ -517,7 +635,7 @@ class Node:
|
|||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
node_title=self.node_data.title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
|
|
|
|||
|
|
@ -9,9 +9,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
|||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
|
||||
|
|
@ -22,32 +21,9 @@ from .exc import (
|
|||
)
|
||||
|
||||
|
||||
class CodeNode(Node):
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
|
||||
_node_data: CodeNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = CodeNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
|
|
@ -70,12 +46,12 @@ class CodeNode(Node):
|
|||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get code language
|
||||
code_language = self._node_data.code_language
|
||||
code = self._node_data.code
|
||||
code_language = self.node_data.code_language
|
||||
code = self.node_data.code
|
||||
|
||||
# Get variables
|
||||
variables = {}
|
||||
for variable_selector in self._node_data.variables:
|
||||
for variable_selector in self.node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
|
|
@ -91,7 +67,7 @@ class CodeNode(Node):
|
|||
)
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
|
||||
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
|
|
@ -428,7 +404,7 @@ class CodeNode(Node):
|
|||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
|
|
|
|||
|
|
@ -20,9 +20,8 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
|
|||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
|
|
@ -38,42 +37,20 @@ from .entities import DatasourceNodeData
|
|||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
|
||||
|
||||
class DatasourceNode(Node):
|
||||
class DatasourceNode(Node[DatasourceNodeData]):
|
||||
"""
|
||||
Datasource Node
|
||||
"""
|
||||
|
||||
_node_data: DatasourceNodeData
|
||||
node_type = NodeType.DATASOURCE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = DatasourceNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
"""
|
||||
|
||||
node_data = self._node_data
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segement:
|
||||
|
|
|
|||