Merge branch 'fix/secret_variable' of github.com:PankajKaushal/dify into fix/secret_variable

This commit is contained in:
Pankaj Kaushal 2025-10-13 16:27:54 +05:30
commit bb53760e6c
546 changed files with 12238 additions and 5279 deletions

View File

@ -1,4 +1,4 @@
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev

View File

@ -30,6 +30,8 @@ jobs:
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union

View File

@ -4,8 +4,7 @@ on:
push:
branches:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "deploy/**"
- "build/**"
- "release/e-*"
- "hotfix/**"

View File

@ -18,7 +18,7 @@ jobs:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@ -1,4 +1,4 @@
name: Deploy RAG Dev
name: Deploy Trigger Dev
permissions:
contents: read
@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
- "deploy/trigger-dev"
types:
- completed
@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@ -40,18 +40,18 @@
<p align="center">
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README/README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./README/README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README/README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README/README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README/README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README/README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README/README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README/README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./README/README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./README/README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./README/README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./README/README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
</p>
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.

View File

@ -343,6 +343,15 @@ OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306
ALIBABACLOUD_MYSQL_USER=root
ALIBABACLOUD_MYSQL_PASSWORD=root
ALIBABACLOUD_MYSQL_DATABASE=dify
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
ALIBABACLOUD_MYSQL_HNSW_M=6
# openGauss configuration
OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600
@ -427,8 +436,8 @@ CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
CODE_MAX_STRING_LENGTH=400000
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000

View File

@ -81,7 +81,6 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
]
[lint.per-file-ignores]

View File

@ -1521,6 +1521,14 @@ def transform_datasource_credentials():
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
if not firecrawl_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
@ -1576,6 +1584,14 @@ def transform_datasource_credentials():
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
if not jina_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")

View File

@ -150,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution",
default=80000,
default=400_000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
@ -362,11 +362,11 @@ class HttpConfig(BaseSettings):
)
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
)
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
)
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
@ -582,6 +582,11 @@ class WorkflowConfig(BaseSettings):
default=200 * 1024,
)
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
description="Maximum number of characters allowed in Template Transform node output",
default=400_000,
)
# GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance",
@ -766,7 +771,7 @@ class MailConfig(BaseSettings):
MAIL_TEMPLATING_TIMEOUT: int = Field(
description="""
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Only available in sandbox mode.""",
default=3,
)

View File

@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
@ -330,6 +331,7 @@ class MiddlewareConfig(
ClickzettaConfig,
HuaweiCloudConfig,
MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,

View File

@ -0,0 +1,54 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AlibabaCloudMySQLConfig(BaseSettings):
"""
Configuration settings for AlibabaCloud MySQL vector database
"""
ALIBABACLOUD_MYSQL_HOST: str = Field(
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
default="localhost",
)
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
default=3306,
)
ALIBABACLOUD_MYSQL_USER: str = Field(
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
default="root",
)
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
default="",
)
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
default="dify",
)
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the connection pool",
default=5,
)
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
default="utf8mb4",
)
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
"(e.g., 'cosine', 'euclidean')",
default="cosine",
)
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
default=6,
)

View File

@ -1,23 +1,24 @@
from enum import Enum
from enum import StrEnum
from typing import Literal
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AuthMethod(StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
class OpenSearchConfig(BaseSettings):
"""
Configuration settings for OpenSearch
"""
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: str | None = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,

View File

@ -1,4 +1,5 @@
from configs import dify_config
from libs.collection_utils import convert_to_lower_and_upper_set
HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]"
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
_doc_extensions: list[str]
_doc_extensions: set[str]
if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
_doc_extensions = {
"txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"vtt",
"properties",
"doc",
"docx",
"csv",
"eml",
"msg",
"pptx",
"xml",
"epub",
}
if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.append("ppt")
_doc_extensions.add("ppt")
else:
_doc_extensions = [
_doc_extensions = {
"txt",
"markdown",
"md",
@ -37,5 +53,5 @@ else:
"csv",
"vtt",
"properties",
]
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)

View File

@ -1,5 +1,4 @@
import flask_restx
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Dataset
from models.model import ApiToken, App
@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = db.session.scalars(
select(ApiToken).where(
@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
current_key_count = (
@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner

View File

@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required
from libs.validators import validate_description_length
from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
@ -28,12 +29,6 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/apps")
class AppListApi(Resource):
@api.doc("list_apps")
@ -138,7 +133,7 @@ class AppListApi(Resource):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
@ -219,7 +214,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
@ -297,7 +292,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
@ -309,7 +304,7 @@ class AppCopyApi(Resource):
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),

View File

@ -70,9 +70,9 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

View File

@ -309,7 +309,7 @@ class ChatConversationApi(Resource):
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]:
case "created_at":

View File

@ -14,6 +14,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models.account import Account
from models.model import AppMode, AppModelConfig
@ -90,7 +91,7 @@ class ModelConfigResource(Resource):
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
@ -124,7 +125,7 @@ class ModelConfigResource(Resource):
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
@ -172,6 +173,8 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_at = naive_utc_now()
db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)

View File

@ -52,7 +52,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource):
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if args["start"]:
@ -190,7 +190,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -263,7 +263,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -345,7 +345,7 @@ FROM
WHERE
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -432,7 +432,7 @@ LEFT JOIN
WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -509,7 +509,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -584,7 +584,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc

View File

@ -25,6 +25,7 @@ from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models import App
@ -674,8 +675,12 @@ class PublishedWorkflowApi(Resource):
marked_comment=args.marked_comment or "",
)
app_model.workflow_id = workflow.id
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
# Update app_model within the same session to ensure atomicity
app_model_in_session = session.get(App, app_model.id)
if app_model_in_session:
app_model_in_session.workflow_id = workflow.id
app_model_in_session.updated_by = current_user.id
app_model_in_session.updated_at = naive_utc_now()
workflow_created_at = TimestampField().format(workflow.created_at)

View File

@ -47,7 +47,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -115,7 +115,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -183,7 +183,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -269,7 +269,7 @@ GROUP BY
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)

View File

@ -103,7 +103,7 @@ class ActivateApi(Resource):
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@ -130,11 +130,11 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
# Check account status
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@ -1,9 +1,9 @@
from flask import request
from flask_login import current_user
from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.billing_service import BillingService
from .. import console_ns
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()

View File

@ -15,7 +15,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
@ -256,14 +256,16 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)

View File

@ -24,13 +24,14 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum
@ -44,10 +45,77 @@ def _validate_name(name: str) -> str:
return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
"""
Get supported retrieval methods based on vector database type.
Args:
vector_type: Vector database type, can be None
is_mock: Whether this is a Mock API, affects MILVUS handling
Returns:
Dictionary containing supported retrieval methods
Raises:
ValueError: If vector_type is None or unsupported
"""
if vector_type is None:
raise ValueError("Vector store type is not configured.")
# Define vector database types that only support semantic search
semantic_only_types = {
VectorType.RELYT,
VectorType.TIDB_VECTOR,
VectorType.CHROMA,
VectorType.PGVECTO_RS,
VectorType.VIKINGDB,
VectorType.UPSTASH,
}
# Define vector database types that support all retrieval methods
full_search_types = {
VectorType.QDRANT,
VectorType.WEAVIATE,
VectorType.OPENSEARCH,
VectorType.ANALYTICDB,
VectorType.MYSCALE,
VectorType.ORACLE,
VectorType.ELASTICSEARCH,
VectorType.ELASTICSEARCH_JA,
VectorType.PGVECTOR,
VectorType.VASTBASE,
VectorType.TIDB_ON_QDRANT,
VectorType.LINDORM,
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,
VectorType.MATRIXONE,
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
full_methods = {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
if vector_type == VectorType.MILVUS:
return semantic_methods if is_mock else full_methods
if vector_type in semantic_only_types:
return semantic_methods
elif vector_type in full_search_types:
return full_methods
else:
raise ValueError(f"Unsupported vector db type {vector_type}.")
@console_ns.route("/datasets")
@ -149,7 +217,7 @@ class DatasetListApi(Resource):
)
parser.add_argument(
"description",
type=_validate_description_length,
type=validate_description_length,
nullable=True,
required=False,
default="",
@ -290,7 +358,7 @@ class DatasetApi(Resource):
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
@ -505,7 +573,7 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=args["doc_form"],
)
@ -517,14 +585,16 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
@ -532,15 +602,17 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
},
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
@ -778,49 +850,7 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = dify_config.VECTOR_STORE
match vector_type:
case (
VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@ -833,48 +863,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required
@account_initialization_required
def get(self, vector_type):
match vector_type:
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.LINDORM
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")

View File

@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db
from fields.document_fields import (
dataset_and_document_fields,
@ -305,7 +305,7 @@ class DatasetDocumentListApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
@ -395,7 +395,7 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -475,7 +475,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
@ -538,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
@ -546,14 +546,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
},
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
@ -561,15 +563,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
},
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)

View File

@ -309,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@ -564,7 +564,7 @@ class ChildChunkAddApi(Resource):
args = parser.parse_args()
try:
chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))

View File

@ -1,7 +1,5 @@
import logging
from typing import cast
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -21,6 +19,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=cast(Account, current_user),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,

View File

@ -28,7 +28,7 @@ class DatasetMetadataCreateApi(Resource):
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -137,7 +137,7 @@ class DocumentMetadataEditApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
metadata_args = MetadataOperationData(**args)
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -1,4 +1,3 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
@ -11,6 +10,7 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required

View File

@ -88,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource):
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200

View File

@ -60,9 +60,9 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@ -87,7 +87,7 @@ class RagPipelineImportConfirmApi(Resource):
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
@ -22,6 +22,7 @@ from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource):
return {"message": "App installed successfully"}
@console_ns.route("/installed-apps/<uuid:installed_app_id>")
class InstalledAppApi(InstalledAppResource):
"""
update and delete an installed app
@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource):
db.session.commit()
return {"result": "success", "message": "App info updated successfully"}
api.add_resource(InstalledAppsListApi, "/installed-apps")
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")

View File

@ -1,7 +1,7 @@
from flask_restx import marshal_with
from controllers.common import fields
from controllers.console import api
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
from services.app_service import AppService
@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp):
"""Get app meta"""
@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model)
api.add_resource(
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
)
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import current_user, login_required
@ -35,6 +35,7 @@ recommended_app_list_fields = {
}
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@login_required
@account_initialization_required
@ -56,13 +57,10 @@ class RecommendedAppListApi(Resource):
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
@console_ns.route("/explore/apps/<uuid:app_id>")
class RecommendedAppApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
app_id = str(app_id)
return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, "/explore/apps")
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")

View File

@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
@ -25,6 +25,7 @@ message_fields = {
}
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource):
return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
)
class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id):
app_model = installed_app.app
@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204
api.add_resource(
SavedMessageListApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)

View File

@ -2,15 +2,15 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_user, login_required
from models import InstalledApp
from models.account import Account
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
installed_app = (
db.session.query(InstalledApp)
.where(
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
assert isinstance(current_user, Account)
app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(

View File

@ -1,11 +1,11 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from constants import HIDDEN_VALUE
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource):
@login_required
@account_initialization_required
def delete(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id

View File

@ -1,7 +1,7 @@
from flask_login import current_user
from flask_restx import Resource, fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.feature_service import FeatureService
from . import api, console_ns
@ -23,6 +23,8 @@ class FeatureApi(Resource):
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
return FeatureService.get_features(current_user.current_tenant_id).model_dump()

View File

@ -1,8 +1,6 @@
import urllib.parse
from typing import cast
import httpx
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
import services
@ -16,6 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_user
from models.account import Account
from services.file_service import FileService
@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user = cast(Account, current_user)
assert isinstance(current_user, Account)
user = current_user
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,

View File

@ -1,12 +1,12 @@
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.model import Tag
from services.tag_service import TagService
@ -24,6 +24,8 @@ class TagListApi(Resource):
@account_initialization_required
@marshal_with(dataset_tag_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
@ -34,8 +36,10 @@ class TagListApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def patch(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
TagService.delete_tag(tag_id)
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailChangeLimitError,
@ -45,6 +45,7 @@ from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@setup_required
@login_required
@ -97,6 +98,7 @@ class AccountInitApi(Resource):
return {"result": "success"}
@console_ns.route("/account/profile")
class AccountProfileApi(Resource):
@setup_required
@login_required
@ -109,6 +111,7 @@ class AccountProfileApi(Resource):
return current_user
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@setup_required
@login_required
@ -130,6 +133,7 @@ class AccountNameApi(Resource):
return updated_account
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@setup_required
@login_required
@ -147,6 +151,7 @@ class AccountAvatarApi(Resource):
return updated_account
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@setup_required
@login_required
@ -164,6 +169,7 @@ class AccountInterfaceLanguageApi(Resource):
return updated_account
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@setup_required
@login_required
@ -181,6 +187,7 @@ class AccountInterfaceThemeApi(Resource):
return updated_account
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@setup_required
@login_required
@ -202,6 +209,7 @@ class AccountTimezoneApi(Resource):
return updated_account
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@setup_required
@login_required
@ -227,6 +235,7 @@ class AccountPasswordApi(Resource):
return {"result": "success"}
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource):
integrate_fields = {
"provider": fields.String,
@ -283,6 +292,7 @@ class AccountIntegrateApi(Resource):
return {"data": integrate_data}
@console_ns.route("/account/delete/verify")
class AccountDeleteVerifyApi(Resource):
@setup_required
@login_required
@ -298,6 +308,7 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@setup_required
@login_required
@ -320,6 +331,7 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required
def post(self):
@ -333,6 +345,7 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
@ -352,6 +365,7 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
@ -396,6 +410,7 @@ class EducationApi(Resource):
return res
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
@ -419,6 +434,7 @@ class EducationAutoCompleteApi(Resource):
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@enable_change_email
@setup_required
@ -467,6 +483,7 @@ class ChangeEmailSendEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@enable_change_email
@setup_required
@ -508,6 +525,7 @@ class ChangeEmailCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@enable_change_email
@setup_required
@ -547,6 +565,7 @@ class ChangeEmailResetApi(Resource):
return updated_account
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@setup_required
def post(self):
@ -558,28 +577,3 @@ class CheckEmailUnique(Resource):
if not AccountService.check_email_unique(args["email"]):
raise EmailAlreadyInUseError()
return {"result": "success"}
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
api.add_resource(AccountNameApi, "/account/name")
api.add_resource(AccountAvatarApi, "/account/avatar")
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# Change email
api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@ -1,10 +1,10 @@
from flask_login import current_user
from flask_restx import Resource, fields
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.agent_service import AgentService
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
assert isinstance(current_user, Account)
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
assert isinstance(current_user, Account)
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden
@ -6,10 +5,18 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@ -34,7 +41,7 @@ class EndpointCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@ -51,7 +58,7 @@ class EndpointCreateApi(Resource):
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
@ -80,7 +87,7 @@ class EndpointListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -93,7 +100,7 @@ class EndpointListApi(Resource):
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
page=page,
page_size=page_size,
@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource):
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
page=page,
@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource):
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.delete_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource):
return {
"success": EndpointService.update_endpoint(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=name,
@ -250,7 +255,7 @@ class EndpointEnableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -262,9 +267,7 @@ class EndpointEnableApi(Resource):
raise Forbidden()
return {
"success": EndpointService.enable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -285,7 +288,7 @@ class EndpointDisableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -297,7 +300,5 @@ class EndpointDisableApi(Resource):
raise Forbidden()
return {
"success": EndpointService.disable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
return response
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
response["error"] = error
return response
# Load Balancing Config
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@ -1,12 +1,11 @@
from urllib import parse
from flask import abort, request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
@ -26,13 +25,14 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
"""List all members of current tenant."""
@ -49,6 +49,7 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@ -111,6 +112,7 @@ class MemberInviteEmailApi(Resource):
}, 201
@console_ns.route("/workspaces/current/members/<uuid:member_id>")
class MemberCancelInviteApi(Resource):
"""Cancel an invitation by member id."""
@ -143,6 +145,7 @@ class MemberCancelInviteApi(Resource):
}, 200
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@ -177,6 +180,7 @@ class MemberUpdateRoleApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/dataset-operators")
class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant."""
@ -193,6 +197,7 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
@ -233,6 +238,7 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
@setup_required
@login_required
@ -278,6 +284,7 @@ class OwnerTransferCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
@setup_required
@login_required
@ -339,14 +346,3 @@ class OwnerTransfer(Resource):
raise ValueError(str(e))
return {"result": "success"}
api.add_resource(MemberListApi, "/workspaces/current/members")
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
# owner transfer
api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer")

View File

@ -5,7 +5,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -17,6 +17,7 @@ from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
@setup_required
@login_required
@ -45,6 +46,7 @@ class ModelProviderListApi(Resource):
return jsonable_encoder({"data": provider_list})
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
@setup_required
@login_required
@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource):
return response
@console_ns.route("/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>")
class ModelProviderIconApi(Resource):
"""
Get model provider icon
@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
@setup_required
@login_required
@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/checkout-url")
class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
prefilled_email=current_user.email,
)
return data
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
)
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
ModelProviderIconApi,
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)

View File

@ -4,7 +4,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
@setup_required
@login_required
@ -85,6 +86,7 @@ class DefaultModelApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@login_required
@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@ -364,6 +367,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
@setup_required
@login_required
@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource):
return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
@setup_required
@login_required
@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource):
return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
@setup_required
@login_required
@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource):
return response
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource):
return jsonable_encoder({"data": parameter_rules})
@console_ns.route("/workspaces/current/models/model-types/<string:model_type>")
class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource):
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<path:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<path:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelCredentialSwitchApi,
"/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
)
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@setup_required
@login_required
@ -55,6 +57,7 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
@setup_required
@login_required
@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource):
return jsonable_encoder({"versions": versions})
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
@setup_required
@login_required
@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource):
return jsonable_encoder({"plugins": plugins})
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
@setup_required
def get(self):
@ -108,6 +113,7 @@ class PluginIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/plugin/upload/pkg")
class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/bundle")
class PluginUploadFromBundleApi(Resource):
@setup_required
@login_required
@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
@setup_required
@login_required
@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>")
class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
@setup_required
@login_required
@ -453,6 +474,7 @@ class PluginUninstallApi(Resource):
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource):
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@console_ns.route("/workspaces/current/plugin/permission/fetch")
class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource):
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
@setup_required
@login_required
@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@setup_required
@login_required
@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource):
return jsonable_encoder({"success": True})
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
class PluginFetchPreferencesApi(Resource):
@setup_required
@login_required
@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required
@login_required
@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")

View File

@ -10,7 +10,7 @@ from flask_restx import (
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool:
return False
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource):
@setup_required
@login_required
@ -71,6 +72,7 @@ class ToolProviderListApi(Resource):
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/info")
class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource):
@setup_required
@login_required
@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource):
@setup_required
@login_required
@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
return result
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credentials")
class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource):
@setup_required
@login_required
@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource):
@setup_required
@login_required
@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource):
@setup_required
@login_required
@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource):
@setup_required
@login_required
@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource):
@setup_required
@login_required
@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource):
@setup_required
@login_required
@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>")
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource):
@setup_required
@login_required
@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource):
@setup_required
@login_required
@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource):
@setup_required
@login_required
@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource):
@setup_required
@login_required
@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource):
@setup_required
@login_required
@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource):
@setup_required
@login_required
@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource):
return jsonable_encoder(tool)
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource):
@setup_required
@login_required
@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource):
)
@console_ns.route("/workspaces/current/tools/builtin")
class ToolBuiltinListApi(Resource):
@setup_required
@login_required
@ -624,6 +647,7 @@ class ToolBuiltinListApi(Resource):
)
@console_ns.route("/workspaces/current/tools/api")
class ToolApiListApi(Resource):
@setup_required
@login_required
@ -642,6 +666,7 @@ class ToolApiListApi(Resource):
)
@console_ns.route("/workspaces/current/tools/workflow")
class ToolWorkflowListApi(Resource):
@setup_required
@login_required
@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource):
)
@console_ns.route("/workspaces/current/tool-labels")
class ToolLabelsApi(Resource):
@setup_required
@login_required
@ -672,6 +698,7 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
@console_ns.route("/oauth/plugin/<path:provider>/tool/authorization-url")
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource):
return response
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema")
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/info")
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
)
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource):
return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource):
@setup_required
@login_required
@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource):
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource):
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
class ToolMCPListAllApi(Resource):
@setup_required
@login_required
@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource):
return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@ -1014,6 +1051,7 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools)
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
def get(self):
parser = reqparse.RequestParser()
@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource):
authorization_code = args["code"]
handle_callback(state_key, authorization_code)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
# api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
# mcp tool provider
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

View File

@ -1,7 +1,6 @@
import logging
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@ -14,7 +13,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console import api
from controllers.console import console_ns
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
@ -24,7 +23,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account, Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService
@ -65,6 +64,7 @@ tenants_fields = {
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
@console_ns.route("/workspaces")
class TenantListApi(Resource):
@setup_required
@login_required
@ -93,6 +93,7 @@ class TenantListApi(Resource):
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource):
@setup_required
@admin_required
@ -118,6 +119,8 @@ class WorkspaceListApi(Resource):
}, 200
@console_ns.route("/workspaces/current", endpoint="workspaces_current")
@console_ns.route("/info", endpoint="info") # Deprecated
class TenantApi(Resource):
@setup_required
@login_required
@ -143,11 +146,10 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
if not tenant:
raise ValueError("No tenant available")
return WorkspaceService.get_tenant_info(tenant), 200
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
@setup_required
@login_required
@ -172,6 +174,7 @@ class SwitchWorkspaceApi(Resource):
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource):
@setup_required
@login_required
@ -202,6 +205,7 @@ class CustomConfigWorkspaceApi(Resource):
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config/webapp-logo/upload")
class WebappLogoWorkspaceApi(Resource):
@setup_required
@login_required
@ -242,6 +246,7 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
@setup_required
@login_required
@ -261,13 +266,3 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info

View File

@ -7,13 +7,13 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from flask_login import current_user
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import AccountStatus
from libs.login import current_user
from models.account import Account, AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
@ -25,11 +25,16 @@ P = ParamSpec("P")
R = TypeVar("R")
def _current_account() -> Account:
assert isinstance(current_user, Account)
return current_user
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = current_user
account = _current_account()
if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"
key = f"rate_limit_{tenant_id}"
redis_client.zadd(key, {current_time: current_time})
@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
OperationService.record_utm(tenant_id, utm_info_dict)
return view(*args, **kwargs)
@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]):
def knowledge_pipeline_publish_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)

View File

@ -25,8 +25,8 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
As a result, it could only be considered as an end user id.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
with Session(db.engine) as session:
user_model = None
@ -85,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None):
raise ValueError("tenant_id is required")
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (
@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo
raise ValueError("invalid json")
try:
payload = payload_type(**data)
payload = payload_type.model_validate(data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")

View File

@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum
from models.provider_ids import ModelProviderID
@ -31,12 +32,6 @@ def _validate_name(name):
return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
# Define parsers for dataset operations
dataset_create_parser = reqparse.RequestParser()
dataset_create_parser.add_argument(
@ -48,7 +43,7 @@ dataset_create_parser.add_argument(
)
dataset_create_parser.add_argument(
"description",
type=_validate_description_length,
type=validate_description_length,
nullable=True,
required=False,
default="",
@ -101,7 +96,7 @@ dataset_update_parser.add_argument(
type=_validate_name,
)
dataset_update_parser.add_argument(
"description", location="json", store_missing=False, type=_validate_description_length
"description", location="json", store_missing=False, type=validate_description_length
)
dataset_update_parser.add_argument(
"indexing_technique",
@ -285,7 +280,7 @@ class DatasetListApi(DatasetApiResource):
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"])
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
)

View File

@ -136,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
# validate args
DocumentService.document_create_args_validate(knowledge_config)
@ -221,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource):
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
@ -426,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:

View File

@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData(**args)
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
parser.add_argument("is_published", type=bool, required=True, location="json")
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)

View File

@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]), segment, document, dataset
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200

View File

@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
session_id=user_id,
)
session.add(end_user)

View File

@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user_id = enterprise_user_decoded.get("end_user_id")
session_id = enterprise_user_decoded.get("session_id")
user_auth_type = enterprise_user_decoded.get("auth_type")
exchanged_token_expires_unix = enterprise_user_decoded.get("exp")
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
)
db.session.add(end_user)
db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp())
exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp())
if exchanged_token_expires_unix:
exp = int(exchanged_token_expires_unix)
payload = {
"iss": site.id,
"sub": "Web API Passport",

View File

@ -40,7 +40,7 @@ class AgentConfigManager:
"credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router",

View File

@ -1,4 +1,5 @@
import uuid
from typing import Literal, cast
from core.app.app_config.entities import (
DatasetEntity,
@ -74,6 +75,9 @@ class DatasetConfigManager:
return None
query_variable = config.get("dataset_query_variable")
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
if dataset_configs["retrieval_model"] == "single":
return DatasetEntity(
dataset_ids=dataset_ids,
@ -82,18 +86,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
if dataset_configs.get("metadata_model_config")
metadata_filtering_mode=cast(
Literal["disabled", "automatic", "manual"],
dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None,
metadata_filtering_conditions=MetadataFilteringCondition(
**dataset_configs.get("metadata_filtering_conditions", {})
)
if dataset_configs.get("metadata_filtering_conditions")
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
if isinstance(metadata_filtering_conditions_dict, dict)
else None,
),
)
else:
score_threshold_val = dataset_configs.get("score_threshold")
reranking_model_val = dataset_configs.get("reranking_model")
weights_val = dataset_configs.get("weights")
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
@ -101,22 +110,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold")
if dataset_configs.get("score_threshold_enabled", False)
top_k=int(dataset_configs.get("top_k", 4)),
score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None,
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=weights_val if isinstance(weights_val, dict) else None,
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
if dataset_configs.get("metadata_model_config")
metadata_filtering_mode=cast(
Literal["disabled", "automatic", "manual"],
dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None,
metadata_filtering_conditions=MetadataFilteringCondition(
**dataset_configs.get("metadata_filtering_conditions", {})
)
if dataset_configs.get("metadata_filtering_conditions")
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
if isinstance(metadata_filtering_conditions_dict, dict)
else None,
),
)
@ -134,18 +144,17 @@ class DatasetConfigManager:
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"}
if "dataset_configs" not in config or not config.get("dataset_configs"):
config["dataset_configs"] = {}
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if not config["dataset_configs"].get("datasets"):
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
).get("datasets")
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
@ -166,8 +175,8 @@ class DatasetConfigManager:
:param config: app model config args
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
if "agent_mode" not in config or not config.get("agent_mode"):
config["agent_mode"] = {}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@ -180,19 +189,22 @@ class DatasetConfigManager:
raise ValueError("enabled in agent_mode must be of boolean type")
# tools
if not config["agent_mode"].get("tools"):
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
# strategy
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER
has_datasets = False
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
for tool in config["agent_mode"]["tools"]:
if config.get("agent_mode", {}).get("strategy") in {
PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER,
}:
for tool in config.get("agent_mode", {}).get("tools", []):
key = list(tool.keys())[0]
if key == "dataset":
# old style, use tool name as key
@ -217,7 +229,7 @@ class DatasetConfigManager:
has_datasets = True
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion

View File

@ -68,9 +68,13 @@ class ModelConfigConverter:
# get model mode
model_mode = model_config.mode
if not model_mode:
model_mode = LLMMode.CHAT.value
model_mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
try:
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
except ValueError:
# Fall back to CHAT mode if the stored value is invalid
model_mode = LLMMode.CHAT
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@ -100,7 +100,7 @@ class PromptTemplateConfigManager:
if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
@ -110,7 +110,7 @@ class PromptTemplateConfigManager:
if not assistant_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
if config["model"]["mode"] == ModelMode.CHAT.value:
if config["model"]["mode"] == ModelMode.CHAT:
prompt_list = config["chat_prompt_config"]["prompt"]
if len(prompt_list) > 10:

View File

@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode must be of boolean type")
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
agent_mode["strategy"] = PlanningStrategy.ROUTER
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")

View File

@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner):
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")

View File

@ -61,9 +61,6 @@ class AppRunner:
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens:

View File

@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")
@ -229,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -100,8 +100,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@ -244,8 +244,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -107,7 +107,6 @@ class MessageCycleManager:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
db.session.merge(conversation)
db.session.commit()
db.session.close()

View File

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional
from openai import BaseModel
from pydantic import Field
from pydantic import BaseModel, Field
# Import InvokeFrom locally to avoid circular import
from core.app.entities.app_invoke_entities import InvokeFrom

View File

@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel):
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES:
parameter["type"] = "files"
# -------------

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel):
@ -11,11 +11,12 @@ class I18nObject(BaseModel):
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -1,5 +1,5 @@
import enum
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@ -54,16 +54,16 @@ class DatasourceParameter(PluginParameter):
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
icon: str = Field(..., description="The icon of the tool")
class DatasourceInvokeFrom(Enum):
class DatasourceInvokeFrom(StrEnum):
"""
Enum class for datasource invoke
"""

View File

@ -5,7 +5,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel):
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
return self
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""
@ -207,7 +207,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_type == ProviderType.CUSTOM,
Provider.provider_name.in_(self._get_provider_names()),
)
@ -458,7 +458,7 @@ class ProviderConfiguration(BaseModel):
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
provider_type=ProviderType.CUSTOM,
is_valid=True,
credential_id=new_record.id,
)
@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel):
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -1,13 +1,13 @@
from typing import cast
import requests
import httpx
from configs import dify_config
from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60)
timeout: httpx.Timeout = httpx.Timeout(60.0, connect=5.0)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str):
@ -27,25 +27,23 @@ class APIBasedExtensionRequestor:
url = self.api_endpoint
try:
# proxy support for security
proxies = None
mounts: dict[str, httpx.BaseTransport] | None = None
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = {
"http": dify_config.SSRF_PROXY_HTTP_URL,
"https": dify_config.SSRF_PROXY_HTTPS_URL,
mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
response = requests.request(
method="POST",
url=url,
json={"point": point.value, "params": params},
headers=headers,
timeout=self.timeout,
proxies=proxies,
)
except requests.Timeout:
with httpx.Client(mounts=mounts, timeout=self.timeout) as client:
response = client.request(
method="POST",
url=url,
json={"point": point.value, "params": params},
headers=headers,
)
except httpx.TimeoutException:
raise ValueError("request timeout")
except requests.ConnectionError:
except httpx.RequestError:
raise ValueError("request connection error")
if response.status_code != 200:

View File

@ -131,7 +131,7 @@ class CodeExecutor:
if (code := response_data.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
response_code = CodeExecutionResponse(**response_data)
response_code = CodeExecutionResponse.model_validate(response_data)
if response_code.data.error:
raise CodeExecutionError(response_code.data.error)

View File

@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_manifests_ignore_deserialization_error(
@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
result.append(MarketplacePluginDeclaration.model_validate(plugin))
except Exception:
pass

View File

@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -343,7 +343,7 @@ class IndexingRunner:
if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
@ -356,15 +356,17 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
},
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
@ -377,15 +379,17 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
},
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

View File

@ -224,8 +224,8 @@ def _handle_native_json_schema(
# Set appropriate response format if required by the model
for rule in rules:
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA
return model_parameters
@ -239,10 +239,10 @@ def _set_response_format(model_parameters: dict, rules: list):
"""
for rule in rules:
if rule.name == "response_format":
if ResponseFormat.JSON.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON.value
elif ResponseFormat.JSON_OBJECT.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
if ResponseFormat.JSON in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON
elif ResponseFormat.JSON_OBJECT in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT
def _handle_prompt_based_schema(

View File

@ -294,7 +294,7 @@ class ClientSession(
method="completion/complete",
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
argument=types.CompletionArgument.model_validate(argument),
),
)
),

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class I18nObject(BaseModel):
@ -9,7 +9,8 @@ class I18nObject(BaseModel):
zh_Hans: str | None = None
en_US: str
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if not self.zh_Hans:
self.zh_Hans = self.en_US
return self

View File

@ -1,13 +1,13 @@
from collections.abc import Sequence
from enum import Enum, StrEnum, auto
from enum import StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
class ConfigurateMethod(Enum):
class ConfigurateMethod(StrEnum):
"""
Enum class for configurate method of provider model.
"""
@ -46,10 +46,11 @@ class FormOption(BaseModel):
value: str
show_on: list[FormShowOnObject] = []
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if not self.label:
self.label = I18nObject(en_US=self.value)
return self
class CredentialFormSchema(BaseModel):

View File

@ -269,17 +269,17 @@ class ModelProviderFactory:
}
if model_type == ModelType.LLM:
return LargeLanguageModel(**init_params) # type: ignore
return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(**init_params) # type: ignore
return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK:
return RerankModel(**init_params) # type: ignore
return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(**init_params) # type: ignore
return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION:
return ModerationModel(**init_params) # type: ignore
return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS:
return TTSModel(**init_params) # type: ignore
return TTSModel.model_validate(init_params)
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""

View File

@ -51,7 +51,7 @@ class ApiModeration(Moderation):
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
return ModerationInputsResult(**result)
return ModerationInputsResult.model_validate(result)
return ModerationInputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
@ -67,7 +67,7 @@ class ApiModeration(Moderation):
params = ModerationOutputParams(app_id=self.app_id, text=text)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
return ModerationOutputsResult(**result)
return ModerationOutputsResult.model_validate(result)
return ModerationOutputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response

View File

@ -213,9 +213,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata.update(json.loads(node_execution.execution_metadata))
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN.value
span_kind = OpenInferenceSpanKindValues.CHAIN
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM.value
span_kind = OpenInferenceSpanKindValues.LLM
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
@ -230,18 +230,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
span_kind = OpenInferenceSpanKindValues.RETRIEVER
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL.value
span_kind = OpenInferenceSpanKindValues.TOOL
else:
span_kind = OpenInferenceSpanKindValues.CHAIN.value
span_kind = OpenInferenceSpanKindValues.CHAIN
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},

View File

@ -73,7 +73,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_id:
trace_id = trace_info.trace_id or trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE.value
name = TraceTaskName.MESSAGE_TRACE
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
@ -88,7 +88,7 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
trace_id=trace_id,
@ -103,7 +103,7 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
@ -253,7 +253,7 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
input={
"message": trace_info.inputs,
"files": file_list,
@ -303,7 +303,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
span_data = LangfuseSpan(
name=TraceTaskName.MODERATION_TRACE.value,
name=TraceTaskName.MODERATION_TRACE,
input=trace_info.inputs,
output={
"action": trace_info.action,
@ -331,7 +331,7 @@ class LangFuseDataTrace(BaseTraceInstance):
)
generation_data = LangfuseGeneration(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
input=trace_info.inputs,
output=str(trace_info.suggested_question),
trace_id=trace_info.trace_id or trace_info.message_id,
@ -349,7 +349,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
dataset_retrieval_span_data = LangfuseSpan(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
input=trace_info.inputs,
output={"documents": trace_info.documents},
trace_id=trace_info.trace_id or trace_info.message_id,
@ -377,7 +377,7 @@ class LangFuseDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
name_generation_trace_data = LangfuseTrace(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
input=trace_info.inputs,
output=trace_info.outputs,
user_id=trace_info.tenant_id,
@ -388,7 +388,7 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_trace(langfuse_trace_data=name_generation_trace_data)
name_generation_span_data = LangfuseSpan(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
input=trace_info.inputs,
output=trace_info.outputs,
trace_id=trace_info.conversation_id,

View File

@ -81,7 +81,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_id:
message_run = LangSmithRunModel(
id=trace_info.message_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
run_type=LangSmithRunType.chain,
@ -110,7 +110,7 @@ class LangSmithDataTrace(BaseTraceInstance):
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
inputs=dict(trace_info.workflow_run_inputs),
run_type=LangSmithRunType.tool,
start_time=trace_info.workflow_data.created_at,
@ -271,7 +271,7 @@ class LangSmithDataTrace(BaseTraceInstance):
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
id=message_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
inputs=trace_info.inputs,
run_type=LangSmithRunType.chain,
start_time=trace_info.start_time,
@ -327,7 +327,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
langsmith_run = LangSmithRunModel(
name=TraceTaskName.MODERATION_TRACE.value,
name=TraceTaskName.MODERATION_TRACE,
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
@ -362,7 +362,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if message_data is None:
return
suggested_question_run = LangSmithRunModel(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
run_type=LangSmithRunType.tool,
@ -391,7 +391,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
dataset_retrieval_run = LangSmithRunModel(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
run_type=LangSmithRunType.retriever,
@ -447,7 +447,7 @@ class LangSmithDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
name_run = LangSmithRunModel(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
inputs=trace_info.inputs,
outputs=trace_info.outputs,
run_type=LangSmithRunType.tool,

View File

@ -108,7 +108,7 @@ class OpikDataTrace(BaseTraceInstance):
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
@ -125,7 +125,7 @@ class OpikDataTrace(BaseTraceInstance):
"id": root_span_id,
"parent_span_id": None,
"trace_id": opik_trace_id,
"name": TraceTaskName.WORKFLOW_TRACE.value,
"name": TraceTaskName.WORKFLOW_TRACE,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"start_time": trace_info.start_time,
@ -138,7 +138,7 @@ class OpikDataTrace(BaseTraceInstance):
else:
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
@ -290,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
@ -329,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.MODERATION_TRACE.value,
"name": TraceTaskName.MODERATION_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
@ -355,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or message_data.updated_at,
@ -375,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
@ -405,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"name": TraceTaskName.GENERATE_NAME_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
@ -420,7 +420,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": trace.id,
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"name": TraceTaskName.GENERATE_NAME_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),

View File

@ -104,7 +104,7 @@ class WeaveDataTrace(BaseTraceInstance):
message_run = WeaveTraceModel(
id=trace_info.message_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
op=str(TraceTaskName.MESSAGE_TRACE),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
total_tokens=trace_info.total_tokens,
@ -126,7 +126,7 @@ class WeaveDataTrace(BaseTraceInstance):
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
op=str(TraceTaskName.WORKFLOW_TRACE.value),
op=str(TraceTaskName.WORKFLOW_TRACE),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
attributes=workflow_attributes,
@ -253,7 +253,7 @@ class WeaveDataTrace(BaseTraceInstance):
message_run = WeaveTraceModel(
id=trace_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
op=str(TraceTaskName.MESSAGE_TRACE),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
@ -300,7 +300,7 @@ class WeaveDataTrace(BaseTraceInstance):
moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.MODERATION_TRACE.value),
op=str(TraceTaskName.MODERATION_TRACE),
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
@ -330,7 +330,7 @@ class WeaveDataTrace(BaseTraceInstance):
suggested_question_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
attributes=attributes,
@ -355,7 +355,7 @@ class WeaveDataTrace(BaseTraceInstance):
dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
attributes=attributes,
@ -397,7 +397,7 @@ class WeaveDataTrace(BaseTraceInstance):
name_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
op=str(TraceTaskName.GENERATE_NAME_TRACE),
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=attributes,

View File

@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,

View File

@ -83,16 +83,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
raise ValueError("prompt_messages must be a list")
for i in range(len(v)):
if v[i]["role"] == PromptMessageRole.USER.value:
v[i] = UserPromptMessage(**v[i])
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
v[i] = AssistantPromptMessage(**v[i])
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
v[i] = SystemPromptMessage(**v[i])
elif v[i]["role"] == PromptMessageRole.TOOL.value:
v[i] = ToolPromptMessage(**v[i])
if v[i]["role"] == PromptMessageRole.USER:
v[i] = UserPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.ASSISTANT:
v[i] = AssistantPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.SYSTEM:
v[i] = SystemPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.TOOL:
v[i] = ToolPromptMessage.model_validate(v[i])
else:
v[i] = PromptMessage(**v[i])
v[i] = PromptMessage.model_validate(v[i])
return v

View File

@ -2,11 +2,10 @@ import inspect
import json
import logging
from collections.abc import Callable, Generator
from typing import TypeVar
from typing import Any, TypeVar
import requests
import httpx
from pydantic import BaseModel
from requests.exceptions import HTTPError
from yarl import URL
from configs import dify_config
@ -47,29 +46,56 @@ class BasePluginClient:
data: bytes | dict | str | None = None,
params: dict | None = None,
files: dict | None = None,
stream: bool = False,
) -> requests.Response:
) -> httpx.Response:
"""
Make a request to the plugin daemon inner API.
"""
url = plugin_daemon_inner_api_baseurl / path
headers = headers or {}
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
headers["Accept-Encoding"] = "gzip, deflate, br"
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
data = json.dumps(data)
request_kwargs: dict[str, Any] = {
"method": method,
"url": url,
"headers": headers,
"params": params,
"files": files,
}
if isinstance(prepared_data, dict):
request_kwargs["data"] = prepared_data
elif prepared_data is not None:
request_kwargs["content"] = prepared_data
try:
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
)
except requests.ConnectionError:
response = httpx.request(**request_kwargs)
except httpx.RequestError:
logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
return response
def _prepare_request(
self,
path: str,
headers: dict | None,
data: bytes | dict | str | None,
params: dict | None,
files: dict | None,
) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]:
url = plugin_daemon_inner_api_baseurl / path
prepared_headers = dict(headers or {})
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
prepared_data: bytes | dict | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
if isinstance(data, dict):
if prepared_headers.get("Content-Type") == "application/json":
prepared_data = json.dumps(data)
else:
prepared_data = data
return str(url), prepared_headers, prepared_data, params, files
def _stream_request(
self,
method: str,
@ -78,23 +104,44 @@ class BasePluginClient:
headers: dict | None = None,
data: bytes | dict | None = None,
files: dict | None = None,
) -> Generator[bytes, None, None]:
) -> Generator[str, None, None]:
"""
Make a stream request to the plugin daemon inner API
"""
response = self._request(method, path, headers, data, params, files, stream=True)
for line in response.iter_lines(chunk_size=1024 * 8):
line = line.decode("utf-8").strip()
if line.startswith("data:"):
line = line[5:].strip()
if line:
yield line
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
stream_kwargs: dict[str, Any] = {
"method": method,
"url": url,
"headers": headers,
"params": params,
"files": files,
}
if isinstance(prepared_data, dict):
stream_kwargs["data"] = prepared_data
elif prepared_data is not None:
stream_kwargs["content"] = prepared_data
try:
with httpx.stream(**stream_kwargs) as response:
for raw_line in response.iter_lines():
if raw_line is None:
continue
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
line = line.strip()
if line.startswith("data:"):
line = line[5:].strip()
if line:
yield line
except httpx.RequestError:
logger.exception("Stream request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
def _stream_request_with_model(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -104,13 +151,13 @@ class BasePluginClient:
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line)) # type: ignore
yield type_(**json.loads(line)) # type: ignore
def _request_with_model(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | None = None,
params: dict | None = None,
@ -120,13 +167,13 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type(**response.json()) # type: ignore
return type_(**response.json()) # type: ignore
def _request_with_plugin_daemon_response(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -139,23 +186,23 @@ class BasePluginClient:
try:
response = self._request(method, path, headers, data, params, files)
response.raise_for_status()
except HTTPError as e:
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
logger.exception(msg)
except httpx.HTTPStatusError as e:
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
raise e
except Exception as e:
msg = f"Failed to request plugin daemon, url: {path}"
logger.exception(msg)
logger.exception("Failed to request plugin daemon, url: %s", path)
raise ValueError(msg) from e
try:
json_response = response.json()
if transformer:
json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
# https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore
except Exception:
msg = (
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
f" url: {path}"
)
logger.exception(msg)
@ -163,7 +210,7 @@ class BasePluginClient:
if rep.code != 0:
try:
error = PluginDaemonError(**json.loads(rep.message))
error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception:
raise ValueError(f"{rep.message}, code: {rep.code}")
@ -178,7 +225,7 @@ class BasePluginClient:
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -189,7 +236,7 @@ class BasePluginClient:
"""
for line in self._stream_request(method, path, params, headers, data, files):
try:
rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore
except (ValueError, TypeError):
# TODO modify this when line_data has code and message
try:
@ -204,11 +251,11 @@ class BasePluginClient:
if rep.code != 0:
if rep.code == -500:
try:
error = PluginDaemonError(**json.loads(rep.message))
error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
logger.error("Error in stream response for plugin %s", rep.__dict__)
self._handle_plugin_daemon_error(error.error_type, error.message)
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
if rep.data is None:

View File

@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient):
params={"page": 1, "page_size": 256},
transformer=transformer,
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
self._get_local_file_datasource_provider()
)
for provider in response:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource provider for the given tenant and plugin.
"""
if provider_id == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())
tool_provider_id = DatasourceProviderID(provider_id)

View File

@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type=LLMResultChunk,
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type=PluginLLMNumTokensResponse,
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type=TextEmbeddingResult,
type_=TextEmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type=PluginTextEmbeddingNumTokensResponse,
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type=RerankResult,
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type=PluginStringResultResponse,
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type=PluginVoicesResponse,
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type=PluginStringResultResponse,
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type=PluginBasicBooleanResponse,
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TypeVar, Union, cast
from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -87,7 +87,8 @@ def merge_blob_chunks(
),
meta=resp.meta,
)
yield cast(MessageType, merged_message)
assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
yield merged_message # type: ignore
# Clean up the buffer
del files[chunk_id]
else:

View File

@ -610,7 +610,7 @@ class ProviderManager:
provider_quota_to_provider_record_dict = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
@ -627,8 +627,8 @@ class ProviderManager:
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
provider_type=ProviderType.SYSTEM,
quota_type=ProviderQuotaType.TRIAL,
quota_limit=quota.quota_limit, # type: ignore
quota_used=0,
is_valid=True,
@ -641,8 +641,8 @@ class ProviderManager:
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == ProviderQuotaType.TRIAL,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@ -702,7 +702,7 @@ class ProviderManager:
"""Get custom provider configuration."""
# Find custom provider record (non-system)
custom_provider_record = next(
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM), None
)
if not custom_provider_record:
@ -905,7 +905,7 @@ class ProviderManager:
# Convert provider_records to dict
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
@ -1046,7 +1046,7 @@ class ProviderManager:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -46,7 +46,7 @@ class DataPostProcessor:
reranking_model: dict | None = None,
weights: dict | None = None,
) -> BaseRerankRunner | None:
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
runner = RerankRunnerFactory.create_rerank_runner(
runner_type=reranking_mode,
tenant_id=tenant_id,
@ -62,7 +62,7 @@ class DataPostProcessor:
),
)
return runner
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
elif reranking_mode == RerankMode.RERANKING_MODEL:
rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
if rerank_model_instance is None:
return None

Some files were not shown because too many files have changed in this diff Show More