mirror of https://github.com/langgenius/dify.git
Merge branch 'fix/secret_variable' of github.com:PankajKaushal/dify into fix/secret_variable
This commit is contained in:
commit
bb53760e6c
|
|
@ -1,4 +1,4 @@
|
|||
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ jobs:
|
|||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
|
||||
# Convert Optional[T] to T | None (ignoring quoted types)
|
||||
cat > /tmp/optional-rule.yml << 'EOF'
|
||||
id: convert-optional-to-union
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ on:
|
|||
push:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "deploy/enterprise"
|
||||
- "deploy/**"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ jobs:
|
|||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
name: Deploy RAG Dev
|
||||
name: Deploy Trigger Dev
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
|
@ -7,7 +7,7 @@ on:
|
|||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/rag-dev"
|
||||
- "deploy/trigger-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
|
|
@ -16,12 +16,12 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
host: ${{ secrets.TRIGGER_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
24
README.md
24
README.md
|
|
@ -40,18 +40,18 @@
|
|||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README/README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./README/README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README/README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README/README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README/README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README/README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README/README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./README/README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./README/README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./README/README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./README/README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./README/README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||
|
|
|
|||
|
|
@ -343,6 +343,15 @@ OCEANBASE_VECTOR_DATABASE=test
|
|||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
|
||||
# AlibabaCloud MySQL Vector configuration
|
||||
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||
ALIBABACLOUD_MYSQL_PORT=3306
|
||||
ALIBABACLOUD_MYSQL_USER=root
|
||||
ALIBABACLOUD_MYSQL_PASSWORD=root
|
||||
ALIBABACLOUD_MYSQL_DATABASE=dify
|
||||
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
|
||||
ALIBABACLOUD_MYSQL_HNSW_M=6
|
||||
|
||||
# openGauss configuration
|
||||
OPENGAUSS_HOST=127.0.0.1
|
||||
OPENGAUSS_PORT=6600
|
||||
|
|
@ -427,8 +436,8 @@ CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
|||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=80000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
|
||||
CODE_MAX_STRING_LENGTH=400000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH=30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
|
|
|
|||
|
|
@ -81,7 +81,6 @@ ignore = [
|
|||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
|
|
|
|||
|
|
@ -1521,6 +1521,14 @@ def transform_datasource_credentials():
|
|||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not firecrawl_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
|
|
@ -1576,6 +1584,14 @@ def transform_datasource_credentials():
|
|||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not jina_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||
|
||||
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
||||
description="Maximum allowed length for strings in code execution",
|
||||
default=80000,
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
||||
|
|
@ -362,11 +362,11 @@ class HttpConfig(BaseSettings):
|
|||
)
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
|
||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
|
||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
|
||||
)
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
|
|
@ -582,6 +582,11 @@ class WorkflowConfig(BaseSettings):
|
|||
default=200 * 1024,
|
||||
)
|
||||
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
|
||||
description="Maximum number of characters allowed in Template Transform node output",
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
|
|
@ -766,7 +771,7 @@ class MailConfig(BaseSettings):
|
|||
|
||||
MAIL_TEMPLATING_TIMEOUT: int = Field(
|
||||
description="""
|
||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||
Only available in sandbox mode.""",
|
||||
default=3,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
|
|||
from .storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
|
||||
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from .vdb.chroma_config import ChromaConfig
|
||||
|
|
@ -330,6 +331,7 @@ class MiddlewareConfig(
|
|||
ClickzettaConfig,
|
||||
HuaweiCloudConfig,
|
||||
MilvusConfig,
|
||||
AlibabaCloudMySQLConfig,
|
||||
MyScaleConfig,
|
||||
OpenSearchConfig,
|
||||
OracleConfig,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AlibabaCloudMySQLConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for AlibabaCloud MySQL vector database
|
||||
"""
|
||||
|
||||
ALIBABACLOUD_MYSQL_HOST: str = Field(
|
||||
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
|
||||
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
|
||||
default=3306,
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_USER: str = Field(
|
||||
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
|
||||
default="root",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
|
||||
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
|
||||
default="",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
|
||||
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Maximum number of connections in the connection pool",
|
||||
default=5,
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
|
||||
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
|
||||
default="utf8mb4",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
|
||||
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
|
||||
"(e.g., 'cosine', 'euclidean')",
|
||||
default="cosine",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
|
||||
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
|
||||
default=6,
|
||||
)
|
||||
|
|
@ -1,23 +1,24 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AuthMethod(StrEnum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for OpenSearch
|
||||
"""
|
||||
|
||||
class AuthMethod(Enum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
OPENSEARCH_HOST: str | None = Field(
|
||||
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
||||
default=None,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from configs import dify_config
|
||||
from libs.collection_utils import convert_to_lower_and_upper_set
|
||||
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UNKNOWN_VALUE = "[__UNKNOWN__]"
|
||||
|
|
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
|||
|
||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
|
||||
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
|
||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
|
||||
|
||||
|
||||
_doc_extensions: list[str]
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
"mdx",
|
||||
"pdf",
|
||||
"html",
|
||||
"htm",
|
||||
"xlsx",
|
||||
"xls",
|
||||
"vtt",
|
||||
"properties",
|
||||
"doc",
|
||||
"docx",
|
||||
"csv",
|
||||
"eml",
|
||||
"msg",
|
||||
"pptx",
|
||||
"xml",
|
||||
"epub",
|
||||
}
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.append("ppt")
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = [
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
|
|
@ -37,5 +53,5 @@ else:
|
|||
"csv",
|
||||
"vtt",
|
||||
"properties",
|
||||
]
|
||||
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
|
|
@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden
|
|||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
|
@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource):
|
|||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
|
|
@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource):
|
|||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
|
|
@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource):
|
|||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
|
|||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import Account, App
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
|
|
@ -28,12 +29,6 @@ from services.feature_service import FeatureService
|
|||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@api.doc("list_apps")
|
||||
|
|
@ -138,7 +133,7 @@ class AppListApi(Resource):
|
|||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
|
|
@ -219,7 +214,7 @@ class AppApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
|
|
@ -297,7 +292,7 @@ class AppCopyApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
|
|
@ -309,7 +304,7 @@ class AppCopyApi(Resource):
|
|||
account = cast(Account, current_user)
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=yaml_content,
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
|
|
|
|||
|
|
@ -70,9 +70,9 @@ class AppImportApi(Resource):
|
|||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
|
@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
|
|||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
|
|
|||
|
|
@ -309,7 +309,7 @@ class ChatConversationApi(Resource):
|
|||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
match args["sort_by"]:
|
||||
case "created_at":
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.tools.tool_manager import ToolManager
|
|||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import AppMode, AppModelConfig
|
||||
|
|
@ -90,7 +91,7 @@ class ModelConfigResource(Resource):
|
|||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
|
|
@ -124,7 +125,7 @@ class ModelConfigResource(Resource):
|
|||
# encrypt agent tool parameters if it's secret-input
|
||||
agent_mode = new_app_model_config.agent_mode_dict
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||
|
||||
# get tool
|
||||
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||
|
|
@ -172,6 +173,8 @@ class ModelConfigResource(Resource):
|
|||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource):
|
|||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
|
||||
)
|
||||
|
||||
if args["start"]:
|
||||
|
|
@ -190,7 +190,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -263,7 +263,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -345,7 +345,7 @@ FROM
|
|||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -432,7 +432,7 @@ LEFT JOIN
|
|||
WHERE
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -509,7 +509,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -584,7 +584,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from factories import file_factory, variable_factory
|
|||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
|
|
@ -674,8 +675,12 @@ class PublishedWorkflowApi(Resource):
|
|||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
||||
# Update app_model within the same session to ensure atomicity
|
||||
app_model_in_session = session.get(App, app_model.id)
|
||||
if app_model_in_session:
|
||||
app_model_in_session.workflow_id = workflow.id
|
||||
app_model_in_session.updated_by = current_user.id
|
||||
app_model_in_session.updated_at = naive_utc_now()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ WHERE
|
|||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
|
@ -115,7 +115,7 @@ WHERE
|
|||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
|
@ -183,7 +183,7 @@ WHERE
|
|||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
|
@ -269,7 +269,7 @@ GROUP BY
|
|||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ class ActivateApi(Resource):
|
|||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -130,11 +130,11 @@ class OAuthCallback(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
||||
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
|
|
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
|
|||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
|
|||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
|
|
@ -256,14 +256,16 @@ class DataSourceNotionApi(Resource):
|
|||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
|
|
|||
|
|
@ -24,13 +24,14 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
|
|
@ -44,10 +45,77 @@ def _validate_name(name: str) -> str:
|
|||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get supported retrieval methods based on vector database type.
|
||||
|
||||
Args:
|
||||
vector_type: Vector database type, can be None
|
||||
is_mock: Whether this is a Mock API, affects MILVUS handling
|
||||
|
||||
Returns:
|
||||
Dictionary containing supported retrieval methods
|
||||
|
||||
Raises:
|
||||
ValueError: If vector_type is None or unsupported
|
||||
"""
|
||||
if vector_type is None:
|
||||
raise ValueError("Vector store type is not configured.")
|
||||
|
||||
# Define vector database types that only support semantic search
|
||||
semantic_only_types = {
|
||||
VectorType.RELYT,
|
||||
VectorType.TIDB_VECTOR,
|
||||
VectorType.CHROMA,
|
||||
VectorType.PGVECTO_RS,
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
}
|
||||
|
||||
# Define vector database types that support all retrieval methods
|
||||
full_search_types = {
|
||||
VectorType.QDRANT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.OPENSEARCH,
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
VectorType.ELASTICSEARCH_JA,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.VASTBASE,
|
||||
VectorType.TIDB_ON_QDRANT,
|
||||
VectorType.LINDORM,
|
||||
VectorType.COUCHBASE,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.OCEANBASE,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.HUAWEI_CLOUD,
|
||||
VectorType.TENCENT,
|
||||
VectorType.MATRIXONE,
|
||||
VectorType.CLICKZETTA,
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
full_methods = {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
|
||||
if vector_type == VectorType.MILVUS:
|
||||
return semantic_methods if is_mock else full_methods
|
||||
|
||||
if vector_type in semantic_only_types:
|
||||
return semantic_methods
|
||||
elif vector_type in full_search_types:
|
||||
return full_methods
|
||||
else:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
|
||||
|
||||
@console_ns.route("/datasets")
|
||||
|
|
@ -149,7 +217,7 @@ class DatasetListApi(Resource):
|
|||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
|
|
@ -290,7 +358,7 @@ class DatasetApi(Resource):
|
|||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
|
|
@ -505,7 +573,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
|
|
@ -517,14 +585,16 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
|
@ -532,15 +602,17 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
"url": url,
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": "crawl",
|
||||
"only_main_content": website_info_list["only_main_content"],
|
||||
},
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
"url": url,
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": "crawl",
|
||||
"only_main_content": website_info_list["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
|
@ -778,49 +850,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
| VectorType.LINDORM
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.MILVUS
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
|
|
@ -833,48 +863,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.LINDORM
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.TENCENT
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
|
|
@ -305,7 +305,7 @@ class DatasetDocumentListApi(Resource):
|
|||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
|
||||
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
|
@ -395,7 +395,7 @@ class DatasetInitApi(Resource):
|
|||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
|
|
@ -475,7 +475,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
|
@ -538,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
|
|
@ -546,14 +546,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
|
@ -561,15 +563,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
},
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
|
|
|||
|
|
@ -309,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
|
|
@ -564,7 +564,7 @@ class ChildChunkAddApi(Resource):
|
|||
args = parser.parse_args()
|
||||
try:
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
|
||||
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -21,6 +19,7 @@ from core.errors.error import (
|
|||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
|
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
|
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
|
|||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=cast(Account, current_user),
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class DatasetMetadataCreateApi(Resource):
|
|||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs(**args)
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
|
@ -137,7 +137,7 @@ class DocumentMetadataEditApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData(**args)
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from fastapi.encoders import jsonable_encoder
|
||||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
|
@ -11,6 +10,7 @@ from controllers.console.wraps import (
|
|||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
|
|
|
|||
|
|
@ -60,9 +60,9 @@ class RagPipelineImportApi(Resource):
|
|||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ class RagPipelineImportConfirmApi(Resource):
|
|||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
|
|||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -22,6 +22,7 @@ from services.feature_service import FeatureService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps")
|
||||
class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource):
|
|||
return {"message": "App installed successfully"}
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>")
|
||||
class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
update and delete an installed app
|
||||
|
|
@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource):
|
|||
db.session.commit()
|
||||
|
||||
return {"result": "success", "message": "App info updated successfully"}
|
||||
|
||||
|
||||
api.add_resource(InstalledAppsListApi, "/installed-apps")
|
||||
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import marshal_with
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
|
|
@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
|
|||
from services.app_service import AppService
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
|
||||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
|
|
@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
|
|||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||
class ExploreAppMetaApi(InstalledAppResource):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Get app meta"""
|
||||
|
|
@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
|
|||
if not app_model:
|
||||
raise ValueError("App not found")
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
api.add_resource(
|
||||
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
|
||||
)
|
||||
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from libs.login import current_user, login_required
|
||||
|
|
@ -35,6 +35,7 @@ recommended_app_list_fields = {
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -56,13 +57,10 @@ class RecommendedAppListApi(Resource):
|
|||
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps/<uuid:app_id>")
|
||||
class RecommendedAppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
return RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
|
||||
|
||||
api.add_resource(RecommendedAppListApi, "/explore/apps")
|
||||
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
|
|||
from flask_restx.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
|
|
@ -25,6 +25,7 @@ message_fields = {
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
|
|
@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
|
||||
)
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
def delete(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
|
|
@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource):
|
|||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
api.add_resource(
|
||||
SavedMessageListApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/saved-messages",
|
||||
endpoint="installed_app_saved_messages",
|
||||
)
|
||||
api.add_resource(
|
||||
SavedMessageApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
|
||||
endpoint="installed_app_saved_message",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,15 +2,15 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import InstalledApp
|
||||
from models.account import Account
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
|
|
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
assert isinstance(current_user, Account)
|
||||
app_id = installed_app.app_id
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
|
@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tenant_id = current_user.current_tenant_id
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
|
|
@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
|
|
@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api, console_ns
|
||||
|
|
@ -23,6 +23,8 @@ class FeatureApi(Resource):
|
|||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -16,6 +14,7 @@ from core.file import helpers as file_helpers
|
|||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
|
@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource):
|
|||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
|
@ -24,6 +24,8 @@ class TagListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(dataset_tag_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_type = request.args.get("type", type=str, default="")
|
||||
keyword = request.args.get("keyword", default=None, type=str)
|
||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||
|
|
@ -34,8 +36,10 @@ class TagListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
|
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailChangeLimitError,
|
||||
|
|
@ -45,6 +45,7 @@ from services.billing_service import BillingService
|
|||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
|
||||
@console_ns.route("/account/init")
|
||||
class AccountInitApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -97,6 +98,7 @@ class AccountInitApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/account/profile")
|
||||
class AccountProfileApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -109,6 +111,7 @@ class AccountProfileApi(Resource):
|
|||
return current_user
|
||||
|
||||
|
||||
@console_ns.route("/account/name")
|
||||
class AccountNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -130,6 +133,7 @@ class AccountNameApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -147,6 +151,7 @@ class AccountAvatarApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-language")
|
||||
class AccountInterfaceLanguageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -164,6 +169,7 @@ class AccountInterfaceLanguageApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
class AccountInterfaceThemeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -181,6 +187,7 @@ class AccountInterfaceThemeApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
@console_ns.route("/account/timezone")
|
||||
class AccountTimezoneApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -202,6 +209,7 @@ class AccountTimezoneApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
@console_ns.route("/account/password")
|
||||
class AccountPasswordApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -227,6 +235,7 @@ class AccountPasswordApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/account/integrates")
|
||||
class AccountIntegrateApi(Resource):
|
||||
integrate_fields = {
|
||||
"provider": fields.String,
|
||||
|
|
@ -283,6 +292,7 @@ class AccountIntegrateApi(Resource):
|
|||
return {"data": integrate_data}
|
||||
|
||||
|
||||
@console_ns.route("/account/delete/verify")
|
||||
class AccountDeleteVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -298,6 +308,7 @@ class AccountDeleteVerifyApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/account/delete")
|
||||
class AccountDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -320,6 +331,7 @@ class AccountDeleteApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/account/delete/feedback")
|
||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
|
|
@ -333,6 +345,7 @@ class AccountDeleteUpdateFeedbackApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/account/education/verify")
|
||||
class EducationVerifyApi(Resource):
|
||||
verify_fields = {
|
||||
"token": fields.String,
|
||||
|
|
@ -352,6 +365,7 @@ class EducationVerifyApi(Resource):
|
|||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||
|
||||
|
||||
@console_ns.route("/account/education")
|
||||
class EducationApi(Resource):
|
||||
status_fields = {
|
||||
"result": fields.Boolean,
|
||||
|
|
@ -396,6 +410,7 @@ class EducationApi(Resource):
|
|||
return res
|
||||
|
||||
|
||||
@console_ns.route("/account/education/autocomplete")
|
||||
class EducationAutoCompleteApi(Resource):
|
||||
data_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
|
|
@ -419,6 +434,7 @@ class EducationAutoCompleteApi(Resource):
|
|||
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email")
|
||||
class ChangeEmailSendEmailApi(Resource):
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
|
|
@ -467,6 +483,7 @@ class ChangeEmailSendEmailApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/validity")
|
||||
class ChangeEmailCheckApi(Resource):
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
|
|
@ -508,6 +525,7 @@ class ChangeEmailCheckApi(Resource):
|
|||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
class ChangeEmailResetApi(Resource):
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
|
|
@ -547,6 +565,7 @@ class ChangeEmailResetApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
class CheckEmailUnique(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
|
|
@ -558,28 +577,3 @@ class CheckEmailUnique(Resource):
|
|||
if not AccountService.check_email_unique(args["email"]):
|
||||
raise EmailAlreadyInUseError()
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
# Register API resources
|
||||
api.add_resource(AccountInitApi, "/account/init")
|
||||
api.add_resource(AccountProfileApi, "/account/profile")
|
||||
api.add_resource(AccountNameApi, "/account/name")
|
||||
api.add_resource(AccountAvatarApi, "/account/avatar")
|
||||
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
|
||||
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
|
||||
api.add_resource(AccountTimezoneApi, "/account/timezone")
|
||||
api.add_resource(AccountPasswordApi, "/account/password")
|
||||
api.add_resource(AccountIntegrateApi, "/account/integrates")
|
||||
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
|
||||
api.add_resource(AccountDeleteApi, "/account/delete")
|
||||
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
|
||||
api.add_resource(EducationVerifyApi, "/account/education/verify")
|
||||
api.add_resource(EducationApi, "/account/education")
|
||||
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
|
||||
# Change email
|
||||
api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
|
||||
api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
|
||||
api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
|
||||
api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
|
||||
# api.add_resource(AccountEmailApi, '/account/email')
|
||||
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
|
|
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
|
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -6,10 +5,18 @@ from controllers.console import api, console_ns
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
def _current_account_with_tenant() -> tuple[Account, str]:
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
return current_user, tenant_id
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@api.doc("create_endpoint")
|
||||
|
|
@ -34,7 +41,7 @@ class EndpointCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -51,7 +58,7 @@ class EndpointCreateApi(Resource):
|
|||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
|
|
@ -80,7 +87,7 @@ class EndpointListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -93,7 +100,7 @@ class EndpointListApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
|
@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
|
|
@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource):
|
|||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource):
|
|||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
|
|
@ -250,7 +255,7 @@ class EndpointEnableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -262,9 +267,7 @@ class EndpointEnableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -285,7 +288,7 @@ class EndpointDisableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -297,7 +300,5 @@ class EndpointDisableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
|
@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole
|
|||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
|
||||
)
|
||||
class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
|
||||
)
|
||||
class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
|||
response["error"] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Load Balancing Config
|
||||
api.add_resource(
|
||||
LoadBalancingCredentialsValidateApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
LoadBalancingConfigCredentialsValidateApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
EmailCodeError,
|
||||
|
|
@ -26,13 +25,14 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
class MemberListApi(Resource):
|
||||
"""List all members of current tenant."""
|
||||
|
||||
|
|
@ -49,6 +49,7 @@ class MemberListApi(Resource):
|
|||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/invite-email")
|
||||
class MemberInviteEmailApi(Resource):
|
||||
"""Invite a new member by email."""
|
||||
|
||||
|
|
@ -111,6 +112,7 @@ class MemberInviteEmailApi(Resource):
|
|||
}, 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>")
|
||||
class MemberCancelInviteApi(Resource):
|
||||
"""Cancel an invitation by member id."""
|
||||
|
||||
|
|
@ -143,6 +145,7 @@ class MemberCancelInviteApi(Resource):
|
|||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
|
||||
class MemberUpdateRoleApi(Resource):
|
||||
"""Update member role."""
|
||||
|
||||
|
|
@ -177,6 +180,7 @@ class MemberUpdateRoleApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/dataset-operators")
|
||||
class DatasetOperatorMemberListApi(Resource):
|
||||
"""List all members of current tenant."""
|
||||
|
||||
|
|
@ -193,6 +197,7 @@ class DatasetOperatorMemberListApi(Resource):
|
|||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
|
||||
class SendOwnerTransferEmailApi(Resource):
|
||||
"""Send owner transfer email."""
|
||||
|
||||
|
|
@ -233,6 +238,7 @@ class SendOwnerTransferEmailApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/owner-transfer-check")
|
||||
class OwnerTransferCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -278,6 +284,7 @@ class OwnerTransferCheckApi(Resource):
|
|||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
|
||||
class OwnerTransfer(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -339,14 +346,3 @@ class OwnerTransfer(Resource):
|
|||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(MemberListApi, "/workspaces/current/members")
|
||||
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
|
||||
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
|
||||
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
|
||||
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
|
||||
# owner transfer
|
||||
api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
|
||||
api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
|
||||
api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_login import current_user
|
|||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
|
@ -17,6 +17,7 @@ from services.billing_service import BillingService
|
|||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
class ModelProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -45,6 +46,7 @@ class ModelProviderListApi(Resource):
|
|||
return jsonable_encoder({"data": provider_list})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
class ModelProviderCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>")
|
||||
class ModelProviderIconApi(Resource):
|
||||
"""
|
||||
Get model provider icon
|
||||
|
|
@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource):
|
|||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/checkout-url")
|
||||
class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||
prefilled_email=current_user.email,
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
api.add_resource(
|
||||
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
|
||||
)
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
|
||||
api.add_resource(
|
||||
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
|
||||
)
|
||||
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
|
||||
api.add_resource(
|
||||
ModelProviderIconApi,
|
||||
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from flask_login import current_user
|
|||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
|
@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
class DefaultModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -85,6 +86,7 @@ class DefaultModelApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
|
||||
class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -364,6 +367,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
|
||||
)
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
|
||||
)
|
||||
class ModelProviderModelDisableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource):
|
|||
return jsonable_encoder({"data": parameter_rules})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/models/model-types/<string:model_type>")
|
||||
class ModelProviderAvailableModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource):
|
|||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
return jsonable_encoder({"data": models})
|
||||
|
||||
|
||||
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
|
||||
api.add_resource(
|
||||
ModelProviderModelEnableApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/enable",
|
||||
endpoint="model-provider-model-enable",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelDisableApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/disable",
|
||||
endpoint="model-provider-model-disable",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelCredentialSwitchApi,
|
||||
"/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
|
||||
)
|
||||
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
|
||||
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService
|
|||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/debugging-key")
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list")
|
||||
class PluginListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -55,6 +57,7 @@ class PluginListApi(Resource):
|
|||
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
|
||||
class PluginListLatestVersionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource):
|
|||
return jsonable_encoder({"versions": versions})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource):
|
|||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/icon")
|
||||
class PluginIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
|
|
@ -108,6 +113,7 @@ class PluginIconApi(Resource):
|
|||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upload/pkg")
|
||||
class PluginUploadFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upload/github")
|
||||
class PluginUploadFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upload/bundle")
|
||||
class PluginUploadFromBundleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/pkg")
|
||||
class PluginInstallFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/github")
|
||||
class PluginInstallFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/marketplace")
|
||||
class PluginInstallFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
|
||||
class PluginFetchMarketplacePkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
|
||||
class PluginFetchManifestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks")
|
||||
class PluginFetchInstallTasksApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>")
|
||||
class PluginFetchInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
|
||||
class PluginDeleteInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
|
||||
class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
|
||||
class PluginDeleteInstallTaskItemApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
|
||||
class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/github")
|
||||
class PluginUpgradeFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/uninstall")
|
||||
class PluginUninstallApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -453,6 +474,7 @@ class PluginUninstallApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/change")
|
||||
class PluginChangePermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource):
|
|||
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/fetch")
|
||||
class PluginFetchPermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
|
||||
class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||
class PluginChangePreferencesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource):
|
|||
return jsonable_encoder({"success": True})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
|
||||
class PluginFetchPreferencesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource):
|
|||
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
|
||||
class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
|||
args = req.parse_args()
|
||||
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
|
||||
|
||||
|
||||
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
|
||||
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
|
||||
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
|
||||
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
|
||||
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
|
||||
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
|
||||
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
|
||||
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
|
||||
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
|
||||
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
|
||||
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
|
||||
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
|
||||
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
|
||||
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
|
||||
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
|
||||
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
|
||||
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
|
||||
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
|
||||
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
|
||||
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
|
||||
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
|
||||
|
||||
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
|
||||
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
|
||||
|
||||
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
|
||||
|
||||
api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
|
||||
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
|
||||
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from flask_restx import (
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
|
|
@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-providers")
|
||||
class ToolProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -71,6 +72,7 @@ class ToolProviderListApi(Resource):
|
|||
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
||||
class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
|||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
|
||||
class ToolBuiltinProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||
return result
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credentials")
|
||||
class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
class ToolBuiltinProviderIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
|
|
@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource):
|
|||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/add")
|
||||
class ToolApiProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/remote")
|
||||
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/tools")
|
||||
class ToolApiProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/update")
|
||||
class ToolApiProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/delete")
|
||||
class ToolApiProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/get")
|
||||
class ToolApiProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>")
|
||||
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/schema")
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
|
||||
class ToolApiProviderPreviousTestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
|
||||
class ToolWorkflowProviderCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
|
||||
class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
|
||||
class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
|
||||
class ToolWorkflowProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource):
|
|||
return jsonable_encoder(tool)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
|
||||
class ToolWorkflowProviderListToolApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tools/builtin")
|
||||
class ToolBuiltinListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -624,6 +647,7 @@ class ToolBuiltinListApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tools/api")
|
||||
class ToolApiListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -642,6 +666,7 @@ class ToolApiListApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tools/workflow")
|
||||
class ToolWorkflowListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-labels")
|
||||
class ToolLabelsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -672,6 +698,7 @@ class ToolLabelsApi(Resource):
|
|||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/tool/authorization-url")
|
||||
class ToolPluginOAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
|
||||
class ToolOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
|
|
@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
|
||||
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema")
|
||||
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/info")
|
||||
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp")
|
||||
class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
|
||||
class ToolMCPAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource):
|
|||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||
class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource):
|
|||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tools/mcp")
|
||||
class ToolMCPListAllApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource):
|
|||
return [tool.to_dict() for tool in tools]
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
||||
class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -1014,6 +1051,7 @@ class ToolMCPUpdateApi(Resource):
|
|||
return jsonable_encoder(tools)
|
||||
|
||||
|
||||
@console_ns.route("/mcp/oauth/callback")
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource):
|
|||
authorization_code = args["code"]
|
||||
handle_callback(state_key, authorization_code)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
# tool provider
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
|
||||
# tool oauth
|
||||
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
|
||||
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
|
||||
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
|
||||
# builtin tool provider
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
||||
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetOauthClientSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
|
||||
)
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
|
||||
# api tool provider
|
||||
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
|
||||
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
|
||||
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
|
||||
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
|
||||
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
|
||||
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
|
||||
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
|
||||
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
|
||||
|
||||
# workflow tool provider
|
||||
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
|
||||
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
|
||||
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
|
||||
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
|
||||
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
|
||||
|
||||
# mcp tool provider
|
||||
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
|
||||
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
||||
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
|
||||
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
|
||||
|
||||
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
|
||||
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
|
||||
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
|
||||
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
|
||||
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
|
@ -14,7 +13,7 @@ from controllers.common.errors import (
|
|||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import (
|
||||
|
|
@ -24,7 +23,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
@ -65,6 +64,7 @@ tenants_fields = {
|
|||
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces")
|
||||
class TenantListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -93,6 +93,7 @@ class TenantListApi(Resource):
|
|||
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/all-workspaces")
|
||||
class WorkspaceListApi(Resource):
|
||||
@setup_required
|
||||
@admin_required
|
||||
|
|
@ -118,6 +119,8 @@ class WorkspaceListApi(Resource):
|
|||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current", endpoint="workspaces_current")
|
||||
@console_ns.route("/info", endpoint="info") # Deprecated
|
||||
class TenantApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -143,11 +146,10 @@ class TenantApi(Resource):
|
|||
else:
|
||||
raise Unauthorized("workspace is archived")
|
||||
|
||||
if not tenant:
|
||||
raise ValueError("No tenant available")
|
||||
return WorkspaceService.get_tenant_info(tenant), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
class SwitchWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -172,6 +174,7 @@ class SwitchWorkspaceApi(Resource):
|
|||
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/custom-config")
|
||||
class CustomConfigWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -202,6 +205,7 @@ class CustomConfigWorkspaceApi(Resource):
|
|||
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/custom-config/webapp-logo/upload")
|
||||
class WebappLogoWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -242,6 +246,7 @@ class WebappLogoWorkspaceApi(Resource):
|
|||
return {"id": upload_file.id}, 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/info")
|
||||
class WorkspaceInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -261,13 +266,3 @@ class WorkspaceInfoApi(Resource):
|
|||
db.session.commit()
|
||||
|
||||
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
|
||||
|
||||
|
||||
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
|
||||
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
|
||||
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
|
||||
api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
|
||||
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
|
||||
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
|
||||
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
|
||||
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from functools import wraps
|
|||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import AccountStatus
|
||||
from libs.login import current_user
|
||||
from models.account import Account, AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
|
@ -25,11 +25,16 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _current_account() -> Account:
|
||||
assert isinstance(current_user, Account)
|
||||
return current_user
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
account = current_user
|
||||
account = _current_account()
|
||||
|
||||
if account.status == AccountStatus.UNINITIALIZED:
|
||||
raise AccountNotInitializedError()
|
||||
|
|
@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
|
@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
|
|
@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
|
|
@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{current_user.current_tenant_id}"
|
||||
key = f"rate_limit_{tenant_id}"
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
|
||||
|
|
@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
if request_count > knowledge_rate_limit.limit:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
|
|
@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
|
||||
OperationService.record_utm(tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]):
|
|||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||
def knowledge_pipeline_publish_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||
As a result, it could only be considered as an end user id.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
user_model = None
|
||||
|
|
@ -85,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None):
|
|||
raise ValueError("tenant_id is required")
|
||||
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
|
|
@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo
|
|||
raise ValueError("invalid json")
|
||||
|
||||
try:
|
||||
payload = payload_type(**data)
|
||||
payload = payload_type.model_validate(data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid payload: {str(e)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager
|
|||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from libs.validators import validate_description_length
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
|
@ -31,12 +32,6 @@ def _validate_name(name):
|
|||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
# Define parsers for dataset operations
|
||||
dataset_create_parser = reqparse.RequestParser()
|
||||
dataset_create_parser.add_argument(
|
||||
|
|
@ -48,7 +43,7 @@ dataset_create_parser.add_argument(
|
|||
)
|
||||
dataset_create_parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
|
|
@ -101,7 +96,7 @@ dataset_update_parser.add_argument(
|
|||
type=_validate_name,
|
||||
)
|
||||
dataset_update_parser.add_argument(
|
||||
"description", location="json", store_missing=False, type=_validate_description_length
|
||||
"description", location="json", store_missing=False, type=validate_description_length
|
||||
)
|
||||
dataset_update_parser.add_argument(
|
||||
"indexing_technique",
|
||||
|
|
@ -285,7 +280,7 @@ class DatasetListApi(DatasetApiResource):
|
|||
external_knowledge_id=args["external_knowledge_id"],
|
||||
embedding_model_provider=args["embedding_model_provider"],
|
||||
embedding_model_name=args["embedding_model"],
|
||||
retrieval_model=RetrievalModel(**args["retrieval_model"])
|
||||
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
|
||||
if args["retrieval_model"] is not None
|
||||
else None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
|
|
@ -221,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
|
|
@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
|
||||
|
|
@ -426,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
|||
def post(self, tenant_id, dataset_id):
|
||||
"""Create metadata for a dataset."""
|
||||
args = metadata_create_parser.parse_args()
|
||||
metadata_args = MetadataArgs(**args)
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
|
@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
|||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
args = document_metadata_parser.parse_args()
|
||||
metadata_args = MetadataOperationData(**args)
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
|||
parser.add_argument("is_published", type=bool, required=True, location="json")
|
||||
args: ParseResult = parser.parse_args()
|
||||
|
||||
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
|
||||
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
|
|
|||
|
|
@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
args = segment_update_parser.parse_args()
|
||||
|
||||
updated_segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs(**args["segment"]), segment, document, dataset
|
||||
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
|
||||
)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
|
|
|||
|
|
@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
|
|||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
end_user = (
|
||||
|
|
@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
|
|||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
|
|
|
|||
|
|
@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
|||
end_user_id = enterprise_user_decoded.get("end_user_id")
|
||||
session_id = enterprise_user_decoded.get("session_id")
|
||||
user_auth_type = enterprise_user_decoded.get("auth_type")
|
||||
exchanged_token_expires_unix = enterprise_user_decoded.get("exp")
|
||||
|
||||
if not user_auth_type:
|
||||
raise Unauthorized("Missing auth_type in the token.")
|
||||
|
||||
|
|
@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
|||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
exp = int(exp_dt.timestamp())
|
||||
|
||||
exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp())
|
||||
if exchanged_token_expires_unix:
|
||||
exp = int(exchanged_token_expires_unix)
|
||||
|
||||
payload = {
|
||||
"iss": site.id,
|
||||
"sub": "Web API Passport",
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class AgentConfigManager:
|
|||
"credential_id": tool.get("credential_id", None),
|
||||
}
|
||||
|
||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
|
||||
|
||||
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
|
||||
"react_router",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import uuid
|
||||
from typing import Literal, cast
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
|
|
@ -74,6 +75,9 @@ class DatasetConfigManager:
|
|||
return None
|
||||
query_variable = config.get("dataset_query_variable")
|
||||
|
||||
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
|
||||
|
||||
if dataset_configs["retrieval_model"] == "single":
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
|
|
@ -82,18 +86,23 @@ class DatasetConfigManager:
|
|||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
score_threshold_val = dataset_configs.get("score_threshold")
|
||||
reranking_model_val = dataset_configs.get("reranking_model")
|
||||
weights_val = dataset_configs.get("weights")
|
||||
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
|
|
@ -101,22 +110,23 @@ class DatasetConfigManager:
|
|||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
top_k=dataset_configs.get("top_k", 4),
|
||||
score_threshold=dataset_configs.get("score_threshold")
|
||||
if dataset_configs.get("score_threshold_enabled", False)
|
||||
top_k=int(dataset_configs.get("top_k", 4)),
|
||||
score_threshold=float(score_threshold_val)
|
||||
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||
else None,
|
||||
reranking_model=dataset_configs.get("reranking_model"),
|
||||
weights=dataset_configs.get("weights"),
|
||||
reranking_enabled=dataset_configs.get("reranking_enabled", True),
|
||||
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
||||
if dataset_configs.get("metadata_model_config")
|
||||
metadata_filtering_mode=cast(
|
||||
Literal["disabled", "automatic", "manual"],
|
||||
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||
),
|
||||
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||
if isinstance(metadata_model_config_dict, dict)
|
||||
else None,
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
||||
)
|
||||
if dataset_configs.get("metadata_filtering_conditions")
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||
else None,
|
||||
),
|
||||
)
|
||||
|
|
@ -134,18 +144,17 @@ class DatasetConfigManager:
|
|||
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
|
||||
|
||||
# dataset_configs
|
||||
if not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {"retrieval_model": "single"}
|
||||
if "dataset_configs" not in config or not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {}
|
||||
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if not config["dataset_configs"].get("datasets"):
|
||||
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
|
||||
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
||||
|
||||
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
|
||||
"datasets", {}
|
||||
).get("datasets")
|
||||
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
|
|
@ -166,8 +175,8 @@ class DatasetConfigManager:
|
|||
:param config: app model config args
|
||||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
if "agent_mode" not in config or not config.get("agent_mode"):
|
||||
config["agent_mode"] = {}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
|
|
@ -180,19 +189,22 @@ class DatasetConfigManager:
|
|||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
# tools
|
||||
if not config["agent_mode"].get("tools"):
|
||||
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
|
||||
config["agent_mode"]["tools"] = []
|
||||
|
||||
if not isinstance(config["agent_mode"]["tools"], list):
|
||||
raise ValueError("tools in agent_mode must be a list of objects")
|
||||
|
||||
# strategy
|
||||
if not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER
|
||||
|
||||
has_datasets = False
|
||||
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
|
||||
for tool in config["agent_mode"]["tools"]:
|
||||
if config.get("agent_mode", {}).get("strategy") in {
|
||||
PlanningStrategy.ROUTER,
|
||||
PlanningStrategy.REACT_ROUTER,
|
||||
}:
|
||||
for tool in config.get("agent_mode", {}).get("tools", []):
|
||||
key = list(tool.keys())[0]
|
||||
if key == "dataset":
|
||||
# old style, use tool name as key
|
||||
|
|
@ -217,7 +229,7 @@ class DatasetConfigManager:
|
|||
|
||||
has_datasets = True
|
||||
|
||||
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
|
||||
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
|
|
|
|||
|
|
@ -68,9 +68,13 @@ class ModelConfigConverter:
|
|||
# get model mode
|
||||
model_mode = model_config.mode
|
||||
if not model_mode:
|
||||
model_mode = LLMMode.CHAT.value
|
||||
model_mode = LLMMode.CHAT
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
try:
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
|
||||
except ValueError:
|
||||
# Fall back to CHAT mode if the stored value is invalid
|
||||
model_mode = LLMMode.CHAT
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class PromptTemplateConfigManager:
|
|||
if config["model"]["mode"] not in model_mode_vals:
|
||||
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
|
||||
|
||||
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
|
||||
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION:
|
||||
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
|
||||
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ class PromptTemplateConfigManager:
|
|||
if not assistant_prefix:
|
||||
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
|
||||
|
||||
if config["model"]["mode"] == ModelMode.CHAT.value:
|
||||
if config["model"]["mode"] == ModelMode.CHAT:
|
||||
prompt_list = config["chat_prompt_config"]["prompt"]
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
if not agent_mode.get("strategy"):
|
||||
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
|
||||
agent_mode["strategy"] = PlanningStrategy.ROUTER
|
||||
|
||||
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
|
||||
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
||||
|
|
|
|||
|
|
@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner):
|
|||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
|
|
|
|||
|
|
@ -61,9 +61,6 @@ class AppRunner:
|
|||
if model_context_tokens is None:
|
||||
return -1
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
if prompt_tokens + max_tokens > model_context_tokens:
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
rag_pipeline_variables = []
|
||||
if workflow.rag_pipeline_variables:
|
||||
for v in workflow.rag_pipeline_variables:
|
||||
rag_pipeline_variable = RAGPipelineVariable(**v)
|
||||
rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
|
||||
if (
|
||||
rag_pipeline_variable.belong_to_node_id
|
||||
in (self.application_generate_entity.start_node_id, "shared")
|
||||
|
|
@ -229,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -100,8 +100,8 @@ class WorkflowBasedAppRunner:
|
|||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
|
@ -244,8 +244,8 @@ class WorkflowBasedAppRunner:
|
|||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -107,7 +107,6 @@ class MessageCycleManager:
|
|||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
|
||||
db.session.merge(conversation)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Import InvokeFrom locally to avoid circular import
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
|||
for datasource in datasources:
|
||||
if datasource.get("parameters"):
|
||||
for parameter in datasource.get("parameters"):
|
||||
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
|
||||
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
|
|
@ -11,11 +11,12 @@ class I18nObject(BaseModel):
|
|||
pt_BR: str | None = Field(default=None)
|
||||
ja_JP: str | None = Field(default=None)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import enum
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
|
@ -54,16 +54,16 @@ class DatasourceParameter(PluginParameter):
|
|||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
STRING = PluginParameterType.STRING
|
||||
NUMBER = PluginParameterType.NUMBER
|
||||
BOOLEAN = PluginParameterType.BOOLEAN
|
||||
SELECT = PluginParameterType.SELECT
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT
|
||||
FILE = PluginParameterType.FILE
|
||||
FILES = PluginParameterType.FILES
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
|
@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
|
|||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class DatasourceInvokeFrom(Enum):
|
||||
class DatasourceInvokeFrom(StrEnum):
|
||||
"""
|
||||
Enum class for datasource invoke
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from collections import defaultdict
|
|||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel):
|
|||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in self.provider.configurate_methods:
|
||||
|
|
@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel):
|
|||
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
|
||||
):
|
||||
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
||||
return self
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
|
||||
"""
|
||||
|
|
@ -207,7 +207,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_type == ProviderType.CUSTOM,
|
||||
Provider.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
|
|
@ -458,7 +458,7 @@ class ProviderConfiguration(BaseModel):
|
|||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
provider_type=ProviderType.CUSTOM,
|
||||
is_valid=True,
|
||||
credential_id=new_record.id,
|
||||
)
|
||||
|
|
@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
from typing import cast
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from models.api_based_extension import APIBasedExtensionPoint
|
||||
|
||||
|
||||
class APIBasedExtensionRequestor:
|
||||
timeout: tuple[int, int] = (5, 60)
|
||||
timeout: httpx.Timeout = httpx.Timeout(60.0, connect=5.0)
|
||||
"""timeout for request connect and read"""
|
||||
|
||||
def __init__(self, api_endpoint: str, api_key: str):
|
||||
|
|
@ -27,25 +27,23 @@ class APIBasedExtensionRequestor:
|
|||
url = self.api_endpoint
|
||||
|
||||
try:
|
||||
# proxy support for security
|
||||
proxies = None
|
||||
mounts: dict[str, httpx.BaseTransport] | None = None
|
||||
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxies = {
|
||||
"http": dify_config.SSRF_PROXY_HTTP_URL,
|
||||
"https": dify_config.SSRF_PROXY_HTTPS_URL,
|
||||
mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
|
||||
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
|
||||
}
|
||||
|
||||
response = requests.request(
|
||||
method="POST",
|
||||
url=url,
|
||||
json={"point": point.value, "params": params},
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
proxies=proxies,
|
||||
)
|
||||
except requests.Timeout:
|
||||
with httpx.Client(mounts=mounts, timeout=self.timeout) as client:
|
||||
response = client.request(
|
||||
method="POST",
|
||||
url=url,
|
||||
json={"point": point.value, "params": params},
|
||||
headers=headers,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
raise ValueError("request timeout")
|
||||
except requests.ConnectionError:
|
||||
except httpx.RequestError:
|
||||
raise ValueError("request connection error")
|
||||
|
||||
if response.status_code != 200:
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ class CodeExecutor:
|
|||
if (code := response_data.get("code")) != 0:
|
||||
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
|
||||
|
||||
response_code = CodeExecutionResponse(**response_data)
|
||||
response_code = CodeExecutionResponse.model_validate(response_data)
|
||||
|
||||
if response_code.data.error:
|
||||
raise CodeExecutionError(response_code.data.error)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
|
|||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
|
|
@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
|||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin in response.json()["data"]["plugins"]:
|
||||
try:
|
||||
result.append(MarketplacePluginDeclaration(**plugin))
|
||||
result.append(MarketplacePluginDeclaration.model_validate(plugin))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor
|
|||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
|
@ -343,7 +343,7 @@ class IndexingRunner:
|
|||
|
||||
if file_detail:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
|
|
@ -356,15 +356,17 @@ class IndexingRunner:
|
|||
):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"document": dataset_document,
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
|
|
@ -377,15 +379,17 @@ class IndexingRunner:
|
|||
):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
},
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info["url"],
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
|
|
|
|||
|
|
@ -224,8 +224,8 @@ def _handle_native_json_schema(
|
|||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA
|
||||
|
||||
return model_parameters
|
||||
|
||||
|
|
@ -239,10 +239,10 @@ def _set_response_format(model_parameters: dict, rules: list):
|
|||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
if ResponseFormat.JSON in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON
|
||||
elif ResponseFormat.JSON_OBJECT in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT
|
||||
|
||||
|
||||
def _handle_prompt_based_schema(
|
||||
|
|
|
|||
|
|
@ -294,7 +294,7 @@ class ClientSession(
|
|||
method="completion/complete",
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
argument=types.CompletionArgument.model_validate(argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
|
|
@ -9,7 +9,8 @@ class I18nObject(BaseModel):
|
|||
zh_Hans: str | None = None
|
||||
en_US: str
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
if not self.zh_Hans:
|
||||
self.zh_Hans = self.en_US
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum, auto
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
class ConfigurateMethod(Enum):
|
||||
class ConfigurateMethod(StrEnum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
|
|
@ -46,10 +46,11 @@ class FormOption(BaseModel):
|
|||
value: str
|
||||
show_on: list[FormShowOnObject] = []
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
if not self.label:
|
||||
self.label = I18nObject(en_US=self.value)
|
||||
return self
|
||||
|
||||
|
||||
class CredentialFormSchema(BaseModel):
|
||||
|
|
|
|||
|
|
@ -269,17 +269,17 @@ class ModelProviderFactory:
|
|||
}
|
||||
|
||||
if model_type == ModelType.LLM:
|
||||
return LargeLanguageModel(**init_params) # type: ignore
|
||||
return LargeLanguageModel.model_validate(init_params)
|
||||
elif model_type == ModelType.TEXT_EMBEDDING:
|
||||
return TextEmbeddingModel(**init_params) # type: ignore
|
||||
return TextEmbeddingModel.model_validate(init_params)
|
||||
elif model_type == ModelType.RERANK:
|
||||
return RerankModel(**init_params) # type: ignore
|
||||
return RerankModel.model_validate(init_params)
|
||||
elif model_type == ModelType.SPEECH2TEXT:
|
||||
return Speech2TextModel(**init_params) # type: ignore
|
||||
return Speech2TextModel.model_validate(init_params)
|
||||
elif model_type == ModelType.MODERATION:
|
||||
return ModerationModel(**init_params) # type: ignore
|
||||
return ModerationModel.model_validate(init_params)
|
||||
elif model_type == ModelType.TTS:
|
||||
return TTSModel(**init_params) # type: ignore
|
||||
return TTSModel.model_validate(init_params)
|
||||
|
||||
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class ApiModeration(Moderation):
|
|||
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
|
||||
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
|
||||
return ModerationInputsResult(**result)
|
||||
return ModerationInputsResult.model_validate(result)
|
||||
|
||||
return ModerationInputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
|
|
@ -67,7 +67,7 @@ class ApiModeration(Moderation):
|
|||
params = ModerationOutputParams(app_id=self.app_id, text=text)
|
||||
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
|
||||
return ModerationOutputsResult(**result)
|
||||
return ModerationOutputsResult.model_validate(result)
|
||||
|
||||
return ModerationOutputsResult(
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
|
|
|
|||
|
|
@ -213,9 +213,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
if node_execution.node_type == "llm":
|
||||
span_kind = OpenInferenceSpanKindValues.LLM.value
|
||||
span_kind = OpenInferenceSpanKindValues.LLM
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
|
|
@ -230,18 +230,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
elif node_execution.node_type == "dataset_retrieval":
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER
|
||||
elif node_execution.node_type == "tool":
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL.value
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
|
||||
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
|
||||
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
|
||||
if trace_info.message_id:
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
name = TraceTaskName.MESSAGE_TRACE.value
|
||||
name = TraceTaskName.MESSAGE_TRACE
|
||||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
|
|
@ -88,7 +88,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
workflow_span_data = LangfuseSpan(
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
trace_id=trace_id,
|
||||
|
|
@ -103,7 +103,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
metadata=metadata,
|
||||
|
|
@ -253,7 +253,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
input={
|
||||
"message": trace_info.inputs,
|
||||
"files": file_list,
|
||||
|
|
@ -303,7 +303,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_data is None:
|
||||
return
|
||||
span_data = LangfuseSpan(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output={
|
||||
"action": trace_info.action,
|
||||
|
|
@ -331,7 +331,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
)
|
||||
|
||||
generation_data = LangfuseGeneration(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=str(trace_info.suggested_question),
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
|
|
@ -349,7 +349,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_data is None:
|
||||
return
|
||||
dataset_retrieval_span_data = LangfuseSpan(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output={"documents": trace_info.documents},
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
|
|
@ -377,7 +377,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
name_generation_trace_data = LangfuseTrace(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=trace_info.outputs,
|
||||
user_id=trace_info.tenant_id,
|
||||
|
|
@ -388,7 +388,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
self.add_trace(langfuse_trace_data=name_generation_trace_data)
|
||||
|
||||
name_generation_span_data = LangfuseSpan(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=trace_info.outputs,
|
||||
trace_id=trace_info.conversation_id,
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_id:
|
||||
message_run = LangSmithRunModel(
|
||||
id=trace_info.message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
run_type=LangSmithRunType.chain,
|
||||
|
|
@ -110,7 +110,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
file_list=trace_info.file_list,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
run_type=LangSmithRunType.tool,
|
||||
start_time=trace_info.workflow_data.created_at,
|
||||
|
|
@ -271,7 +271,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
run_type=LangSmithRunType.chain,
|
||||
start_time=trace_info.start_time,
|
||||
|
|
@ -327,7 +327,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_data is None:
|
||||
return
|
||||
langsmith_run = LangSmithRunModel(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
|
|
@ -362,7 +362,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
if message_data is None:
|
||||
return
|
||||
suggested_question_run = LangSmithRunModel(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
run_type=LangSmithRunType.tool,
|
||||
|
|
@ -391,7 +391,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_data is None:
|
||||
return
|
||||
dataset_retrieval_run = LangSmithRunModel(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
run_type=LangSmithRunType.retriever,
|
||||
|
|
@ -447,7 +447,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
name_run = LangSmithRunModel(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
run_type=LangSmithRunType.tool,
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
|
|
@ -125,7 +125,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
"id": root_span_id,
|
||||
"parent_span_id": None,
|
||||
"trace_id": opik_trace_id,
|
||||
"name": TraceTaskName.WORKFLOW_TRACE.value,
|
||||
"name": TraceTaskName.WORKFLOW_TRACE,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"start_time": trace_info.start_time,
|
||||
|
|
@ -138,7 +138,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
else:
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
|
|
@ -290,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(metadata),
|
||||
|
|
@ -329,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.MODERATION_TRACE.value,
|
||||
"name": TraceTaskName.MODERATION_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
|
||||
|
|
@ -355,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or message_data.updated_at,
|
||||
|
|
@ -375,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
|
||||
|
|
@ -405,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
|
|
@ -420,7 +420,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
span_data = {
|
||||
"trace_id": trace.id,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_info.message_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||
op=str(TraceTaskName.MESSAGE_TRACE),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
total_tokens=trace_info.total_tokens,
|
||||
|
|
@ -126,7 +126,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
file_list=trace_info.file_list,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
op=str(TraceTaskName.WORKFLOW_TRACE.value),
|
||||
op=str(TraceTaskName.WORKFLOW_TRACE),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
attributes=workflow_attributes,
|
||||
|
|
@ -253,7 +253,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||
op=str(TraceTaskName.MESSAGE_TRACE),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
|
|
@ -300,7 +300,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
moderation_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.MODERATION_TRACE.value),
|
||||
op=str(TraceTaskName.MODERATION_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
|
|
@ -330,7 +330,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
suggested_question_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
|
||||
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
attributes=attributes,
|
||||
|
|
@ -355,7 +355,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
dataset_retrieval_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
|
||||
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
attributes=attributes,
|
||||
|
|
@ -397,7 +397,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
name_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
|
||||
op=str(TraceTaskName.GENERATE_NAME_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
attributes=attributes,
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
|||
instruction=instruction, # instruct with variables are not supported
|
||||
)
|
||||
node_data_dict = node_data.model_dump()
|
||||
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
|
||||
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
|
||||
execution = workflow_service.run_free_workflow_node(
|
||||
node_data_dict,
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -83,16 +83,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
|||
raise ValueError("prompt_messages must be a list")
|
||||
|
||||
for i in range(len(v)):
|
||||
if v[i]["role"] == PromptMessageRole.USER.value:
|
||||
v[i] = UserPromptMessage(**v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
|
||||
v[i] = AssistantPromptMessage(**v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
|
||||
v[i] = SystemPromptMessage(**v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.TOOL.value:
|
||||
v[i] = ToolPromptMessage(**v[i])
|
||||
if v[i]["role"] == PromptMessageRole.USER:
|
||||
v[i] = UserPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.ASSISTANT:
|
||||
v[i] = AssistantPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.SYSTEM:
|
||||
v[i] = SystemPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.TOOL:
|
||||
v[i] = ToolPromptMessage.model_validate(v[i])
|
||||
else:
|
||||
v[i] = PromptMessage(**v[i])
|
||||
v[i] = PromptMessage.model_validate(v[i])
|
||||
|
||||
return v
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,10 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from requests.exceptions import HTTPError
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -47,29 +46,56 @@ class BasePluginClient:
|
|||
data: bytes | dict | str | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
stream: bool = False,
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API.
|
||||
"""
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
headers = headers or {}
|
||||
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
headers["Accept-Encoding"] = "gzip, deflate, br"
|
||||
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
|
||||
|
||||
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
request_kwargs["data"] = prepared_data
|
||||
elif prepared_data is not None:
|
||||
request_kwargs["content"] = prepared_data
|
||||
|
||||
try:
|
||||
response = requests.request(
|
||||
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
|
||||
)
|
||||
except requests.ConnectionError:
|
||||
response = httpx.request(**request_kwargs)
|
||||
except httpx.RequestError:
|
||||
logger.exception("Request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
return response
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
path: str,
|
||||
headers: dict | None,
|
||||
data: bytes | dict | str | None,
|
||||
params: dict | None,
|
||||
files: dict | None,
|
||||
) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]:
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
prepared_headers = dict(headers or {})
|
||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
||||
|
||||
prepared_data: bytes | dict | str | None = (
|
||||
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
if prepared_headers.get("Content-Type") == "application/json":
|
||||
prepared_data = json.dumps(data)
|
||||
else:
|
||||
prepared_data = data
|
||||
|
||||
return str(url), prepared_headers, prepared_data, params, files
|
||||
|
||||
def _stream_request(
|
||||
self,
|
||||
method: str,
|
||||
|
|
@ -78,23 +104,44 @@ class BasePluginClient:
|
|||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> Generator[bytes, None, None]:
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files, stream=True)
|
||||
for line in response.iter_lines(chunk_size=1024 * 8):
|
||||
line = line.decode("utf-8").strip()
|
||||
if line.startswith("data:"):
|
||||
line = line[5:].strip()
|
||||
if line:
|
||||
yield line
|
||||
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
|
||||
|
||||
stream_kwargs: dict[str, Any] = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
stream_kwargs["data"] = prepared_data
|
||||
elif prepared_data is not None:
|
||||
stream_kwargs["content"] = prepared_data
|
||||
|
||||
try:
|
||||
with httpx.stream(**stream_kwargs) as response:
|
||||
for raw_line in response.iter_lines():
|
||||
if raw_line is None:
|
||||
continue
|
||||
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
|
||||
line = line.strip()
|
||||
if line.startswith("data:"):
|
||||
line = line[5:].strip()
|
||||
if line:
|
||||
yield line
|
||||
except httpx.RequestError:
|
||||
logger.exception("Stream request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
def _stream_request_with_model(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
type_: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
|
|
@ -104,13 +151,13 @@ class BasePluginClient:
|
|||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
yield type(**json.loads(line)) # type: ignore
|
||||
yield type_(**json.loads(line)) # type: ignore
|
||||
|
||||
def _request_with_model(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
type_: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | None = None,
|
||||
params: dict | None = None,
|
||||
|
|
@ -120,13 +167,13 @@ class BasePluginClient:
|
|||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
return type(**response.json()) # type: ignore
|
||||
return type_(**response.json()) # type: ignore
|
||||
|
||||
def _request_with_plugin_daemon_response(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
type_: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
|
|
@ -139,23 +186,23 @@ class BasePluginClient:
|
|||
try:
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
response.raise_for_status()
|
||||
except HTTPError as e:
|
||||
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
|
||||
logger.exception(msg)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
|
||||
raise e
|
||||
except Exception as e:
|
||||
msg = f"Failed to request plugin daemon, url: {path}"
|
||||
logger.exception(msg)
|
||||
logger.exception("Failed to request plugin daemon, url: %s", path)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
try:
|
||||
json_response = response.json()
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
|
||||
# https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
|
||||
rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore
|
||||
except Exception:
|
||||
msg = (
|
||||
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
|
||||
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
|
||||
f" url: {path}"
|
||||
)
|
||||
logger.exception(msg)
|
||||
|
|
@ -163,7 +210,7 @@ class BasePluginClient:
|
|||
|
||||
if rep.code != 0:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
error = PluginDaemonError.model_validate(json.loads(rep.message))
|
||||
except Exception:
|
||||
raise ValueError(f"{rep.message}, code: {rep.code}")
|
||||
|
||||
|
|
@ -178,7 +225,7 @@ class BasePluginClient:
|
|||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
type_: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
|
|
@ -189,7 +236,7 @@ class BasePluginClient:
|
|||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
try:
|
||||
rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
|
||||
rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
# TODO modify this when line_data has code and message
|
||||
try:
|
||||
|
|
@ -204,11 +251,11 @@ class BasePluginClient:
|
|||
if rep.code != 0:
|
||||
if rep.code == -500:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
error = PluginDaemonError.model_validate(json.loads(rep.message))
|
||||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
|
||||
logger.error("Error in stream response for plugin %s", rep.__dict__)
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
|
|
|
|||
|
|
@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient):
|
|||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
|
||||
self._get_local_file_datasource_provider()
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
|
@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient):
|
|||
Fetch datasource provider for the given tenant and plugin.
|
||||
"""
|
||||
if provider_id == "langgenius/file/file":
|
||||
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())
|
||||
|
||||
tool_provider_id = DatasourceProviderID(provider_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
|
||||
type=LLMResultChunk,
|
||||
type_=LLMResultChunk,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
|
||||
type=PluginLLMNumTokensResponse,
|
||||
type_=PluginLLMNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
||||
type=TextEmbeddingResult,
|
||||
type_=TextEmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
|
||||
type=PluginTextEmbeddingNumTokensResponse,
|
||||
type_=PluginTextEmbeddingNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
|
||||
type=RerankResult,
|
||||
type_=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
|
||||
type=PluginStringResultResponse,
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
|
||||
type=PluginVoicesResponse,
|
||||
type_=PluginVoicesResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
|
||||
type=PluginStringResultResponse,
|
||||
type_=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient):
|
|||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
|
||||
type=PluginBasicBooleanResponse,
|
||||
type_=PluginBasicBooleanResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeVar, Union, cast
|
||||
from typing import TypeVar, Union
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
|
@ -87,7 +87,8 @@ def merge_blob_chunks(
|
|||
),
|
||||
meta=resp.meta,
|
||||
)
|
||||
yield cast(MessageType, merged_message)
|
||||
assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
|
||||
yield merged_message # type: ignore
|
||||
# Clean up the buffer
|
||||
del files[chunk_id]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -610,7 +610,7 @@ class ProviderManager:
|
|||
|
||||
provider_quota_to_provider_record_dict = {}
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
|
|
@ -627,8 +627,8 @@ class ProviderManager:
|
|||
tenant_id=tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_limit=quota.quota_limit, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
|
|
@ -641,8 +641,8 @@ class ProviderManager:
|
|||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
Provider.provider_type == ProviderType.SYSTEM,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL,
|
||||
)
|
||||
existed_provider_record = db.session.scalar(stmt)
|
||||
if not existed_provider_record:
|
||||
|
|
@ -702,7 +702,7 @@ class ProviderManager:
|
|||
"""Get custom provider configuration."""
|
||||
# Find custom provider record (non-system)
|
||||
custom_provider_record = next(
|
||||
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
|
||||
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM), None
|
||||
)
|
||||
|
||||
if not custom_provider_record:
|
||||
|
|
@ -905,7 +905,7 @@ class ProviderManager:
|
|||
# Convert provider_records to dict
|
||||
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
|
|
@ -1046,7 +1046,7 @@ class ProviderManager:
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class DataPostProcessor:
|
|||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
) -> BaseRerankRunner | None:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
|
||||
runner = RerankRunnerFactory.create_rerank_runner(
|
||||
runner_type=reranking_mode,
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -62,7 +62,7 @@ class DataPostProcessor:
|
|||
),
|
||||
)
|
||||
return runner
|
||||
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
|
||||
elif reranking_mode == RerankMode.RERANKING_MODEL:
|
||||
rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
|
||||
if rerank_model_instance is None:
|
||||
return None
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue