Merge branch 'main' into feat/hitl-frontend

This commit is contained in:
twwu 2025-12-12 17:59:08 +08:00
commit ccef15aafa
742 changed files with 26364 additions and 17398 deletions

8
.github/CODEOWNERS vendored
View File

@ -9,6 +9,14 @@
# Backend (default owner, more specific rules below will override) # Backend (default owner, more specific rules below will override)
api/ @QuantumGhost api/ @QuantumGhost
# Backend - MCP
api/core/mcp/ @Nov1c444
api/core/entities/mcp_provider.py @Nov1c444
api/services/tools/mcp_tools_manage_service.py @Nov1c444
api/controllers/mcp/ @Nov1c444
api/controllers/console/app/mcp_server.py @Nov1c444
api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine) # Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost api/core/workflow/runtime/ @laipz8200 @QuantumGhost

View File

@ -1,8 +1,6 @@
name: "✨ Refactor" name: "✨ Refactor or Chore"
description: Refactor existing code for improved readability and maintainability. description: Refactor existing code or perform maintenance chores to improve readability and reliability.
title: "[Chore/Refactor] " title: "[Refactor/Chore] "
labels:
- refactor
body: body:
- type: checkboxes - type: checkboxes
attributes: attributes:
@ -11,7 +9,7 @@ body:
options: options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true required: true
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). - label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true required: true
@ -25,14 +23,14 @@ body:
id: description id: description
attributes: attributes:
label: Description label: Description
placeholder: "Describe the refactor you are proposing." placeholder: "Describe the refactor or chore you are proposing."
validations: validations:
required: true required: true
- type: textarea - type: textarea
id: motivation id: motivation
attributes: attributes:
label: Motivation label: Motivation
placeholder: "Explain why this refactor is necessary." placeholder: "Explain why this refactor or chore is necessary."
validations: validations:
required: false required: false
- type: textarea - type: textarea

View File

@ -1,13 +0,0 @@
name: "👾 Tracker"
description: For inner usages, please do not use this template.
title: "[Tracker] "
labels:
- tracker
body:
- type: textarea
id: content
attributes:
label: Blockers
placeholder: "- [ ] ..."
validations:
required: true

View File

@ -0,0 +1,21 @@
name: Semantic Pull Request
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
jobs:
lint:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -106,7 +106,7 @@ jobs:
- name: Web type check - name: Web type check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web working-directory: ./web
run: pnpm run type-check run: pnpm run type-check:tsgo
docker-compose-template: docker-compose-template:
name: Docker Compose Template name: Docker Compose Template

1
.nvmrc Normal file
View File

@ -0,0 +1 @@
22.11.0

View File

@ -24,8 +24,8 @@ The codebase is split into:
```bash ```bash
cd web cd web
pnpm lint
pnpm lint:fix pnpm lint:fix
pnpm type-check:tsgo
pnpm test pnpm test
``` ```
@ -39,7 +39,7 @@ pnpm test
## Language Style ## Language Style
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). - **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types. - **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
## General Practices ## General Practices

View File

@ -139,6 +139,19 @@ Star Dify on GitHub and be instantly notified of new releases.
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
#### Customizing Suggested Questions
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
```bash
# In your .env file
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
SUGGESTED_QUESTIONS_MAX_TOKENS=512
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
```
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
### Metrics Monitoring with Grafana ### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more. Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.

View File

@ -633,8 +633,30 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Set to false to export dataset IDs as plain text for easier cross-environment import # Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Suggested Questions After Answer Configuration
# These environment variables allow customization of the suggested questions feature
#
# Custom prompt for generating suggested questions (optional)
# If not set, uses the default prompt that generates 3 questions under 20 characters each
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
# SUGGESTED_QUESTIONS_PROMPT=
# Maximum number of tokens for suggested questions generation (default: 256)
# Adjust this value for longer questions or more questions
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
# Temperature for suggested questions generation (default: 0.0)
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
# SUGGESTED_QUESTIONS_TEMPERATURE=0
# Tenant isolated task queue configuration # Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1 TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited) # Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0 DATASET_MAX_SEGMENTS_PER_REQUEST=0
# Multimodal knowledgebase limit
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10

View File

@ -36,17 +36,20 @@ select = [
"UP", # pyupgrade rules "UP", # pyupgrade rules
"W191", # tab-indentation "W191", # tab-indentation
"W605", # invalid-escape-sequence "W605", # invalid-escape-sequence
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum,
"S110", # disallow the try-except-pass pattern.
# security related linting rules # security related linting rules
# RCE proctection (sort of) # RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec` "S102", # exec-builtin, disallow use of `exec`
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval` "S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module "S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage "S311", # suspicious-non-cryptographic-random-usage,
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum
] ]
ignore = [ ignore = [
@ -91,18 +94,16 @@ ignore = [
"configs/*" = [ "configs/*" = [
"N802", # invalid-function-name "N802", # invalid-function-name
] ]
"core/model_runtime/callbacks/base_callback.py" = [ "core/model_runtime/callbacks/base_callback.py" = ["T201"]
"T201", "core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
]
"core/workflow/callbacks/workflow_logging_callback.py" = [
"T201",
]
"libs/gmpy2_pkcs10aep_cipher.py" = [ "libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name "N803", # invalid-argument-name
] ]
"tests/*" = [ "tests/*" = [
"F811", # redefined-while-unused "F811", # redefined-while-unused
"T201", # allow print in tests "T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
] ]
[lint.pyflakes] [lint.pyflakes]

View File

@ -1,6 +1,8 @@
import logging import logging
import time import time
from opentelemetry.trace import get_current_span
from configs import dify_config from configs import dify_config
from contexts.wrapper import RecyclableContextVar from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp from dify_app import DifyApp
@ -26,8 +28,25 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request # add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles() RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
@dify_app.after_request
def add_trace_id_header(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction # Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request _ = before_request
_ = add_trace_id_header
return dify_app return dify_app

View File

@ -1139,6 +1139,7 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
except Exception as e: except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
return
all_files_on_storage = [] all_files_on_storage = []
for storage_path in storage_paths: for storage_path in storage_paths:

View File

@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings):
default=10, default=10,
) )
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a image batch upload operation",
default=10,
)
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a single chunk attachment",
default=10,
)
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for attachments in megabytes",
default=2,
)
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
description="Timeout for downloading image attachments in seconds",
default=60,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field( inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=( description=(
"Comma-separated list of file extensions that are blocked from upload. " "Comma-separated list of file extensions that are blocked from upload. "
@ -553,7 +573,10 @@ class LoggingConfig(BaseSettings):
LOG_FORMAT: str = Field( LOG_FORMAT: str = Field(
description="Format string for log messages", description="Format string for log messages",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", default=(
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
),
) )
LOG_DATEFORMAT: str | None = Field( LOG_DATEFORMAT: str | None = Field(

View File

@ -0,0 +1,26 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
from flask_restx import Namespace
from pydantic import BaseModel
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a single BaseModel with a namespace for Swagger documentation."""
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
"""Register multiple BaseModels with a namespace."""
for model in models:
register_schema_model(namespace, model)
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"register_schema_model",
"register_schema_models",
]

View File

@ -3,7 +3,8 @@ from functools import wraps
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
@ -18,6 +19,30 @@ from extensions.ext_database import db
from libs.token import extract_access_token from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InsertExploreAppPayload(BaseModel):
app_id: str = Field(...)
desc: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]): def admin_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
@ -40,59 +65,34 @@ def admin_required(view: Callable[P, R]):
class InsertExploreAppListApi(Resource): class InsertExploreAppListApi(Resource):
@console_ns.doc("insert_explore_app") @console_ns.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list") @console_ns.doc(description="Insert or update an app in the explore list")
@console_ns.expect( @console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
console_ns.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@console_ns.response(200, "App updated successfully") @console_ns.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully") @console_ns.response(201, "App inserted successfully")
@console_ns.response(404, "App not found") @console_ns.response(404, "App not found")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
parser = ( payload = InsertExploreAppPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
.add_argument("desc", type=str, location="json")
.add_argument("copyright", type=str, location="json")
.add_argument("privacy_policy", type=str, location="json")
.add_argument("custom_disclaimer", type=str, location="json")
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
if not app: if not app:
raise NotFound(f"App '{args['app_id']}' is not found") raise NotFound(f"App '{payload.app_id}' is not found")
site = app.site site = app.site
if not site: if not site:
desc = args["desc"] or "" desc = payload.desc or ""
copy_right = args["copyright"] or "" copy_right = payload.copyright or ""
privacy_policy = args["privacy_policy"] or "" privacy_policy = payload.privacy_policy or ""
custom_disclaimer = args["custom_disclaimer"] or "" custom_disclaimer = payload.custom_disclaimer or ""
else: else:
desc = site.description or args["desc"] or "" desc = site.description or payload.desc or ""
copy_right = site.copyright or args["copyright"] or "" copy_right = site.copyright or payload.copyright or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or "" privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session: with Session(db.engine) as session:
recommended_app = session.execute( recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none() ).scalar_one_or_none()
if not recommended_app: if not recommended_app:
@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right, copyright=copy_right,
privacy_policy=privacy_policy, privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer, custom_disclaimer=custom_disclaimer,
language=args["language"], language=payload.language,
category=args["category"], category=payload.category,
position=args["position"], position=payload.position,
) )
db.session.add(recommended_app) db.session.add(recommended_app)
@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource):
recommended_app.copyright = copy_right recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"] recommended_app.language = payload.language
recommended_app.category = args["category"] recommended_app.category = payload.category
recommended_app.position = args["position"] recommended_app.position = payload.position
app.is_public = True app.is_public = True

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -8,10 +10,21 @@ from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from services.agent_service import AgentService from services.agent_service import AgentService
parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID") class AgentLogQuery(BaseModel):
message_id: str = Field(..., description="Message UUID")
conversation_id: str = Field(..., description="Conversation UUID")
@field_validator("message_id", "conversation_id")
@classmethod
def validate_uuid(cls, value: str) -> str:
return uuid_value(value)
console_ns.schema_model(
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
) )
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
@console_ns.doc("get_agent_logs") @console_ns.doc("get_agent_logs")
@console_ns.doc(description="Get agent execution logs for an application") @console_ns.doc(description="Get agent execution logs for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[AgentLogQuery.__name__])
@console_ns.response( @console_ns.response(
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")) 200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
) )
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model): def get(self, app_model):
"""Get agent logs""" """Get agent logs"""
args = parser.parse_args() args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)

View File

@ -1,7 +1,8 @@
from typing import Literal from typing import Any, Literal
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns from controllers.console import console_ns
@ -21,22 +22,79 @@ from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AnnotationReplyPayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold for annotation matching")
embedding_provider_name: str = Field(..., description="Embedding provider name")
embedding_model_name: str = Field(..., description="Embedding model name")
class AnnotationSettingUpdatePayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Page size")
keyword: str = Field(default="", description="Search keyword")
class CreateAnnotationPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
question: str | None = Field(default=None, description="Question text")
answer: str | None = Field(default=None, description="Answer text")
content: str | None = Field(default=None, description="Content text")
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class UpdateAnnotationPayload(BaseModel):
question: str | None = None
answer: str | None = None
content: str | None = None
annotation_reply: dict[str, Any] | None = None
class AnnotationReplyStatusQuery(BaseModel):
action: Literal["enable", "disable"]
class AnnotationFilePayload(BaseModel):
message_id: str = Field(..., description="Message ID")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
def reg(model: type[BaseModel]) -> None:
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AnnotationReplyPayload)
reg(AnnotationSettingUpdatePayload)
reg(AnnotationListQuery)
reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>") @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@console_ns.doc("annotation_reply_action") @console_ns.doc("annotation_reply_action")
@console_ns.doc(description="Enable or disable annotation reply for an app") @console_ns.doc(description="Enable or disable annotation reply for an app")
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) @console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
console_ns.model(
"AnnotationReplyActionRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
},
)
)
@console_ns.response(200, "Action completed successfully") @console_ns.response(200, "Action completed successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -46,15 +104,9 @@ class AnnotationReplyActionApi(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]): def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id) app_id = str(app_id)
parser = ( args = AnnotationReplyPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("score_threshold", required=True, type=float, location="json")
.add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args()
if action == "enable": if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable": elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id) result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200 return result, 200
@ -82,16 +134,7 @@ class AppAnnotationSettingUpdateApi(Resource):
@console_ns.doc("update_annotation_setting") @console_ns.doc("update_annotation_setting")
@console_ns.doc(description="Update annotation settings for an app") @console_ns.doc(description="Update annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) @console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
console_ns.model(
"AnnotationSettingUpdateRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
"embedding_model_name": fields.String(required=True, description="Embedding model"),
},
)
)
@console_ns.response(200, "Settings updated successfully") @console_ns.response(200, "Settings updated successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -102,10 +145,9 @@ class AppAnnotationSettingUpdateApi(Resource):
app_id = str(app_id) app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id) annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json") args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200 return result, 200
@ -142,12 +184,7 @@ class AnnotationApi(Resource):
@console_ns.doc("list_annotations") @console_ns.doc("list_annotations")
@console_ns.doc(description="Get annotations for an app with pagination") @console_ns.doc(description="Get annotations for an app with pagination")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
)
@console_ns.response(200, "Annotations retrieved successfully") @console_ns.response(200, "Annotations retrieved successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -155,9 +192,10 @@ class AnnotationApi(Resource):
@account_initialization_required @account_initialization_required
@edit_permission_required @edit_permission_required
def get(self, app_id): def get(self, app_id):
page = request.args.get("page", default=1, type=int) args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
limit = request.args.get("limit", default=20, type=int) page = args.page
keyword = request.args.get("keyword", default="", type=str) limit = args.limit
keyword = args.keyword
app_id = str(app_id) app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
@ -173,18 +211,7 @@ class AnnotationApi(Resource):
@console_ns.doc("create_annotation") @console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
console_ns.model(
"CreateAnnotationRequest",
{
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
)
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -195,16 +222,9 @@ class AnnotationApi(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_id): def post(self, app_id):
app_id = str(app_id) app_id = str(app_id)
parser = ( args = CreateAnnotationPayload.model_validate(console_ns.payload)
reqparse.RequestParser() data = args.model_dump(exclude_none=True)
.add_argument("message_id", required=False, type=uuid_value, location="json") annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation return annotation
@setup_required @setup_required
@ -256,13 +276,6 @@ class AnnotationExportApi(Resource):
return response, 200 return response, 200
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>") @console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource): class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation") @console_ns.doc("update_delete_annotation")
@ -271,7 +284,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(204, "Annotation deleted successfully") @console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.expect(parser) @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -281,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
args = parser.parse_args() args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return annotation return annotation
@setup_required @setup_required

View File

@ -31,7 +31,6 @@ from fields.app_fields import (
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
from libs.helper import AppIconUrlField, TimestampField from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import App, Workflow from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
@ -76,51 +75,30 @@ class AppListQuery(BaseModel):
class CreateAppPayload(BaseModel): class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name") name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type") icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon") icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color") icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class UpdateAppPayload(BaseModel): class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name") name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type") icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon") icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color") icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests") max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class CopyAppPayload(BaseModel): class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app") name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app") description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type") icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon") icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color") icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class AppExportQuery(BaseModel): class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export") include_secret: bool = Field(default=False, description="Include secrets in export")
@ -146,7 +124,14 @@ class AppApiStatusPayload(BaseModel):
class AppTracePayload(BaseModel): class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing") enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str = Field(..., description="Tracing provider") tracing_provider: str | None = Field(default=None, description="Tracing provider")
@field_validator("tracing_provider")
@classmethod
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
if info.data.get("enabled") and not value:
raise ValueError("tracing_provider is required when enabled is True")
return value
def reg(cls: type[BaseModel]): def reg(cls: type[BaseModel]):
@ -324,10 +309,13 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN, NodeType.TRIGGER_PLUGIN,
} }
for workflow in draft_workflows: for workflow in draft_workflows:
for _, node_data in workflow.walk_nodes(): try:
if node_data.get("type") in trigger_node_types: for _, node_data in workflow.walk_nodes():
draft_trigger_app_ids.add(str(workflow.app_id)) if node_data.get("type") in trigger_node_types:
break draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
continue
for app in app_pagination.items: for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids app.has_draft_trigger = str(app.id) in draft_trigger_app_ids

View File

@ -1,4 +1,5 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy "AppImportCheckDependencies", app_import_check_dependencies_fields_copy
) )
parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json") class AppImportPayload(BaseModel):
.add_argument("yaml_url", type=str, location="json") mode: str = Field(..., description="Import mode")
.add_argument("name", type=str, location="json") yaml_content: str | None = None
.add_argument("description", type=str, location="json") yaml_url: str | None = None
.add_argument("icon_type", type=str, location="json") name: str | None = None
.add_argument("icon", type=str, location="json") description: str | None = None
.add_argument("icon_background", type=str, location="json") icon_type: str | None = None
.add_argument("app_id", type=str, location="json") icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
) )
@console_ns.route("/apps/imports") @console_ns.route("/apps/imports")
class AppImportApi(Resource): class AppImportApi(Resource):
@console_ns.expect(parser) @console_ns.expect(console_ns.models[AppImportPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -61,7 +68,7 @@ class AppImportApi(Resource):
def post(self): def post(self):
# Check user role first # Check user role first
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser.parse_args() args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
@ -70,15 +77,15 @@ class AppImportApi(Resource):
account = current_user account = current_user
result = import_service.import_app( result = import_service.import_app(
account=account, account=account,
import_mode=args["mode"], import_mode=args.mode,
yaml_content=args.get("yaml_content"), yaml_content=args.yaml_content,
yaml_url=args.get("yaml_url"), yaml_url=args.yaml_url,
name=args.get("name"), name=args.name,
description=args.get("description"), description=args.description,
icon_type=args.get("icon_type"), icon_type=args.icon_type,
icon=args.get("icon"), icon=args.icon,
icon_background=args.get("icon_background"), icon_background=args.icon_background,
app_id=args.get("app_id"), app_id=args.app_id,
) )
session.commit() session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,7 +1,8 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
@ -32,6 +33,27 @@ from services.errors.audio import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
text: str = Field(..., description="Text to convert")
voice: str | None = Field(default=None, description="Voice name")
streaming: bool | None = Field(default=None, description="Whether to stream audio")
class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text") @console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
@console_ns.doc("chat_message_text_to_speech") @console_ns.doc("chat_message_text_to_speech")
@console_ns.doc(description="Convert text to speech for chat messages") @console_ns.doc(description="Convert text to speech for chat messages")
@console_ns.doc(params={"app_id": "App ID"}) @console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
console_ns.model(
"TextToSpeechRequest",
{
"message_id": fields.String(description="Message ID"),
"text": fields.String(required=True, description="Text to convert to speech"),
"voice": fields.String(description="Voice to use for TTS"),
"streaming": fields.Boolean(description="Whether to stream the audio"),
},
)
)
@console_ns.response(200, "Text to speech conversion successful") @console_ns.response(200, "Text to speech conversion successful")
@console_ns.response(400, "Bad request - Invalid parameters") @console_ns.response(400, "Bad request - Invalid parameters")
@get_app_model @get_app_model
@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_model: App): def post(self, app_model: App):
try: try:
parser = ( payload = TextToSpeechPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("message_id", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True app_model=app_model,
text=payload.text,
voice=payload.voice,
message_id=payload.message_id,
is_draft=True,
) )
return response return response
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
@ -159,9 +164,7 @@ class TextModesApi(Resource):
@console_ns.doc("get_text_to_speech_voices") @console_ns.doc("get_text_to_speech_voices")
@console_ns.doc(description="Get available TTS voices for a specific language") @console_ns.doc(description="Get available TTS voices for a specific language")
@console_ns.doc(params={"app_id": "App ID"}) @console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
)
@console_ns.response( @console_ns.response(
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")) 200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
) )
@ -172,12 +175,11 @@ class TextModesApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
try: try:
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args") args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
response = AudioService.transcript_tts_voices( response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
language=args["language"], language=args.language,
) )
return response return response

View File

@ -49,7 +49,6 @@ class CompletionConversationQuery(BaseConversationQuery):
class ChatConversationQuery(BaseConversationQuery): class ChatConversationQuery(BaseConversationQuery):
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction" default="-updated_at", description="Sort field and direction"
) )
@ -509,14 +508,6 @@ class ChatConversationApi(Resource):
.having(func.count(MessageAnnotation.id) == 0) .having(func.count(MessageAnnotation.id) == 0)
) )
if args.message_count_gte and args.message_count_gte >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args.message_count_gte)
)
if app_model.mode == AppMode.ADVANCED_CHAT: if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)

View File

@ -1,7 +1,8 @@
import json import json
from enum import StrEnum from enum import StrEnum
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import console_ns from controllers.console import console_ns
@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger # Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields) app_server_model = console_ns.model("AppServer", app_server_fields)
@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
INACTIVE = "inactive" INACTIVE = "inactive"
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server") @console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource): class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server") @console_ns.doc("get_app_mcp_server")
@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("create_app_mcp_server") @console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
console_ns.model(
"MCPServerCreateRequest",
{
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
},
)
)
@console_ns.response(201, "MCP server configuration created successfully", app_server_model) @console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@account_initialization_required @account_initialization_required
@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
parser = ( payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args()
description = args.get("description") description = payload.description
if not description: if not description:
description = app_model.description or "" description = app_model.description or ""
server = AppMCPServer( server = AppMCPServer(
name=app_model.name, name=app_model.name,
description=description, description=description,
parameters=json.dumps(args["parameters"], ensure_ascii=False), parameters=json.dumps(payload.parameters, ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE, status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id, app_id=app_model.id,
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("update_app_mcp_server") @console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
console_ns.model(
"MCPServerUpdateRequest",
{
"id": fields.String(required=True, description="Server ID"),
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
"status": fields.String(description="Server status"),
},
)
)
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model) @console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found") @console_ns.response(404, "Server not found")
@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
@marshal_with(app_server_model) @marshal_with(app_server_model)
@edit_permission_required @edit_permission_required
def put(self, app_model): def put(self, app_model):
parser = ( payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
.add_argument("id", type=str, required=True, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
.add_argument("status", type=str, required=False, location="json")
)
args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server: if not server:
raise NotFound() raise NotFound()
description = args.get("description") description = payload.description
if description is None: if description is None:
pass pass
elif not description: elif not description:
@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
else: else:
server.description = description server.description = description
server.parameters = json.dumps(args["parameters"], ensure_ascii=False) server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if args["status"]: if payload.status:
if args["status"] not in [status.value for status in AppMCPServerStatus]: if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status") raise ValueError("Invalid status")
server.status = args["status"] server.status = payload.status
db.session.commit() db.session.commit()
return server return server

View File

@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel):
class MessageFeedbackPayload(BaseModel): class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID") message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id") @field_validator("message_id")
@classmethod @classmethod
@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource):
db.session.delete(feedback) db.session.delete(feedback)
elif args.rating and feedback: elif args.rating and feedback:
feedback.rating = args.rating feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback: elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists") raise ValueError("rating cannot be None when feedback not exists")
else: else:
@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource):
conversation_id=message.conversation_id, conversation_id=message.conversation_id,
message_id=message.id, message_id=message.id,
rating=rating_value, rating=rating_value,
content=args.content,
from_source="admin", from_source="admin",
from_account_id=current_user.id, from_account_id=current_user.id,
) )

View File

@ -1,4 +1,8 @@
from flask_restx import Resource, fields, reqparse from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import console_ns
@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required from libs.login import login_required
from services.ops_service import OpsService from services.ops_service import OpsService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TraceProviderQuery(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
class TraceConfigPayload(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
console_ns.schema_model(
TraceProviderQuery.__name__,
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/trace-config") @console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource): class TraceAppConfigApi(Resource):
@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("get_trace_app_config") @console_ns.doc("get_trace_app_config")
@console_ns.doc(description="Get tracing configuration for an application") @console_ns.doc(description="Get tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.response( @console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
) )
@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
try: try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not trace_config: if not trace_config:
return {"has_not_configured": True} return {"has_not_configured": True}
return trace_config return trace_config
@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("create_trace_app_config") @console_ns.doc("create_trace_app_config")
@console_ns.doc(description="Create a new tracing configuration for an application") @console_ns.doc(description="Create a new tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
console_ns.model(
"TraceConfigCreateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
},
)
)
@console_ns.response( @console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
) )
@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_id): def post(self, app_id):
"""Create a new trace app configuration""" """Create a new trace app configuration"""
parser = ( args = TraceConfigPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
try: try:
result = OpsService.create_tracing_app_config( result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
) )
if not result: if not result:
raise TracingConfigIsExist() raise TracingConfigIsExist()
@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("update_trace_app_config") @console_ns.doc("update_trace_app_config")
@console_ns.doc(description="Update an existing tracing configuration for an application") @console_ns.doc(description="Update an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
console_ns.model(
"TraceConfigUpdateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
},
)
)
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) @console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@console_ns.response(400, "Invalid request parameters or configuration not found") @console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required @setup_required
@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def patch(self, app_id): def patch(self, app_id):
"""Update an existing trace app configuration""" """Update an existing trace app configuration"""
parser = ( args = TraceConfigPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
try: try:
result = OpsService.update_tracing_app_config( result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
) )
if not result: if not result:
raise TracingConfigNotExist() raise TracingConfigNotExist()
@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("delete_trace_app_config") @console_ns.doc("delete_trace_app_config")
@console_ns.doc(description="Delete an existing tracing configuration for an application") @console_ns.doc(description="Delete an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.response(204, "Tracing configuration deleted successfully") @console_ns.response(204, "Tracing configuration deleted successfully")
@console_ns.response(400, "Invalid request parameters or configuration not found") @console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required @setup_required
@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def delete(self, app_id): def delete(self, app_id):
"""Delete an existing trace app configuration""" """Delete an existing trace app configuration"""
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
try: try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not result: if not result:
raise TracingConfigNotExist() raise TracingConfigNotExist()
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -1,4 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from constants.languages import supported_language from constants.languages import supported_language
@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Site from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
icon_type: str | None = Field(default=None)
icon: str | None = Field(default=None)
icon_background: str | None = Field(default=None)
description: str | None = Field(default=None)
default_language: str | None = Field(default=None)
chat_color_theme: str | None = Field(default=None)
chat_color_theme_inverted: bool | None = Field(default=None)
customize_domain: str | None = Field(default=None)
copyright: str | None = Field(default=None)
privacy_policy: str | None = Field(default=None)
custom_disclaimer: str | None = Field(default=None)
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
prompt_public: bool | None = Field(default=None)
show_workflow_steps: bool | None = Field(default=None)
use_icon_as_answer_icon: bool | None = Field(default=None)
@field_validator("default_language")
@classmethod
def validate_language(cls, value: str | None) -> str | None:
if value is None:
return value
return supported_language(value)
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register model for flask_restx to avoid dict type issues in Swagger # Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields) app_site_model = console_ns.model("AppSite", app_site_fields)
def parse_app_site_args():
parser = (
reqparse.RequestParser()
.add_argument("title", type=str, required=False, location="json")
.add_argument("icon_type", type=str, required=False, location="json")
.add_argument("icon", type=str, required=False, location="json")
.add_argument("icon_background", type=str, required=False, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("default_language", type=supported_language, required=False, location="json")
.add_argument("chat_color_theme", type=str, required=False, location="json")
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
.add_argument("customize_domain", type=str, required=False, location="json")
.add_argument("copyright", type=str, required=False, location="json")
.add_argument("privacy_policy", type=str, required=False, location="json")
.add_argument("custom_disclaimer", type=str, required=False, location="json")
.add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/site") @console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource): class AppSite(Resource):
@console_ns.doc("update_app_site") @console_ns.doc("update_app_site")
@console_ns.doc(description="Update application site configuration") @console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
console_ns.model(
"AppSiteRequest",
{
"title": fields.String(description="Site title"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"description": fields.String(description="Site description"),
"default_language": fields.String(description="Default language"),
"chat_color_theme": fields.String(description="Chat color theme"),
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
"customize_domain": fields.String(description="Custom domain"),
"copyright": fields.String(description="Copyright text"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"customize_token_strategy": fields.String(
enum=["must", "allow", "not_allow"], description="Token strategy"
),
"prompt_public": fields.Boolean(description="Make prompt public"),
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
},
)
)
@console_ns.response(200, "Site configuration updated successfully", app_site_model) @console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found") @console_ns.response(404, "App not found")
@ -89,7 +73,7 @@ class AppSite(Resource):
@get_app_model @get_app_model
@marshal_with(app_site_model) @marshal_with(app_site_model)
def post(self, app_model): def post(self, app_model):
args = parse_app_site_args() args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:
@ -113,7 +97,7 @@ class AppSite(Resource):
"show_workflow_steps", "show_workflow_steps",
"use_icon_as_answer_icon", "use_icon_as_answer_icon",
]: ]:
value = args.get(attr_name) value = getattr(args, attr_name)
if value is not None: if value is not None:
setattr(site, attr_name, value) setattr(site, attr_name, value)

View File

@ -1,10 +1,11 @@
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import NoReturn, ParamSpec, TypeVar from typing import Any, NoReturn, ParamSpec, TypeVar
from flask import Response from flask import Response, request
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import console_ns from controllers.console import console_ns
@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowDraftVariableListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
class WorkflowDraftVariableUpdatePayload(BaseModel):
name: str | None = Field(default=None, description="Variable name")
value: Any | None = Field(default=None, description="Variable value")
console_ns.schema_model(
WorkflowDraftVariableListQuery.__name__,
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkflowDraftVariableUpdatePayload.__name__,
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _convert_values_to_json_serializable_object(value: Segment): def _convert_values_to_json_serializable_object(value: Segment):
@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
return _convert_values_to_json_serializable_object(value) return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type value_type = workflow_draft_var.value_type
return value_type.exposed_type().value return value_type.exposed_type().value
@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource): class WorkflowVariableCollectionApi(Resource):
@console_ns.expect(_create_pagination_parser()) @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.doc("get_workflow_variables") @console_ns.doc("get_workflow_variables")
@console_ns.doc(description="Get draft workflow variables") @console_ns.doc(description="Get draft workflow variables")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
""" """
Get draft workflow Get draft workflow
""" """
parser = _create_pagination_parser() args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -323,15 +328,7 @@ class VariableApi(Resource):
@console_ns.doc("update_variable") @console_ns.doc("update_variable")
@console_ns.doc(description="Update a workflow variable") @console_ns.doc(description="Update a workflow variable")
@console_ns.expect( @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
console_ns.model(
"UpdateVariableRequest",
{
"name": fields.String(description="Variable name"),
"value": fields.Raw(description="Variable value"),
},
)
)
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model) @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found") @console_ns.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
@ -358,16 +355,10 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),
) )
args = parser.parse_args(strict=True) args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = draft_var_srv.get_variable(variable_id=variable_id) variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None: if variable is None:
@ -375,8 +366,8 @@ class VariableApi(Resource):
if variable.app_id != app_model.id: if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}") raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None) new_name = args_model.name
raw_value = args.get(self._PATCH_VALUE_FIELD, None) raw_value = args_model.value
if new_name is None and raw_value is None: if new_name is None and raw_value is None:
return variable return variable

View File

@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/trigger-enable") @console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource): class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True) @console_ns.expect(console_ns.models[ParserEnable.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,28 +1,53 @@
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone from libs.helper import EmailStr, extract_remote_ip, timezone
from models import AccountStatus from models import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import AccountService, RegisterService
active_check_parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address") class ActivateCheckQuery(BaseModel):
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token") workspace_id: str | None = Field(default=None)
) email: EmailStr | None = Field(default=None)
token: str
class ActivatePayload(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
name: str = Field(..., max_length=30)
interface_language: str = Field(...)
timezone: str = Field(...)
@field_validator("interface_language")
@classmethod
def validate_lang(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_tz(cls, value: str) -> str:
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/activate/check") @console_ns.route("/activate/check")
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
@console_ns.doc("check_activation_token") @console_ns.doc("check_activation_token")
@console_ns.doc(description="Check if activation token is valid") @console_ns.doc(description="Check if activation token is valid")
@console_ns.expect(active_check_parser) @console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
), ),
) )
def get(self): def get(self):
args = active_check_parser.parse_args() args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args["workspace_id"] workspaceId = args.workspace_id
reg_email = args["email"] reg_email = args.email
token = args["token"] token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation: if invitation:
@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
return {"is_valid": False} return {"is_valid": False}
active_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
.add_argument("email", type=email, required=False, nullable=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
)
@console_ns.route("/activate") @console_ns.route("/activate")
class ActivateApi(Resource): class ActivateApi(Resource):
@console_ns.doc("activate_account") @console_ns.doc("activate_account")
@console_ns.doc(description="Activate account with invitation token") @console_ns.doc(description="Activate account with invitation token")
@console_ns.expect(active_parser) @console_ns.expect(console_ns.models[ActivatePayload.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Account activated successfully", "Account activated successfully",
@ -85,19 +99,19 @@ class ActivateApi(Resource):
) )
@console_ns.response(400, "Already activated or invalid token") @console_ns.response(400, "Already activated or invalid token")
def post(self): def post(self):
args = active_parser.parse_args() args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
if invitation is None: if invitation is None:
raise AlreadyActivateError() raise AlreadyActivateError()
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) RegisterService.revoke_token(args.workspace_id, args.email, args.token)
account = invitation["account"] account = invitation["account"]
account.name = args["name"] account.name = args.name
account.interface_language = args["interface_language"] account.interface_language = args.interface_language
account.timezone = args["timezone"] account.timezone = args.timezone
account.interface_theme = "light" account.interface_theme = "light"
account.status = AccountStatus.ACTIVE account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now() account.initialized_at = naive_utc_now()

View File

@ -1,12 +1,26 @@
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ApiKeyAuthBindingPayload(BaseModel):
category: str = Field(...)
provider: str = Field(...)
credentials: dict = Field(...)
console_ns.schema_model(
ApiKeyAuthBindingPayload.__name__,
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/api-key-auth/data-source") @console_ns.route("/api-key-auth/data-source")
@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_admin_or_owner_required @is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self): def post(self):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
parser = ( payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
reqparse.RequestParser() data = payload.model_dump()
.add_argument("category", type=str, required=True, nullable=False, location="json") ApiKeyAuthService.validate_api_key_auth_args(data)
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try: try:
ApiKeyAuthService.create_provider_auth(current_tenant_id, args) ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
except Exception as e: except Exception as e:
raise ApiKeyAuthFailedError(str(e)) raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -5,12 +5,11 @@ from flask import current_app, redirect, request
from flask_restx import Resource, fields from flask_restx import Resource, fields
from configs import dify_config from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import login_required from libs.login import login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,5 +1,6 @@
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -14,16 +15,45 @@ from controllers.console.auth.error import (
InvalidTokenError, InvalidTokenError,
PasswordMismatchError, PasswordMismatchError,
) )
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models import Account from models import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
from ..error import AccountInFreezeError, EmailSendIpLimitError
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EmailRegisterSendPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
language: str | None = Field(default=None, description="Language code")
class EmailRegisterValidityPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class EmailRegisterResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/email-register/send-email") @console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource): class EmailRegisterSendEmailApi(Resource):
@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = ( args = EmailRegisterSendPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
language = "en-US" language = "en-US"
if args["language"] in languages: if args.language in languages:
language = args["language"] language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError() raise AccountInFreezeError()
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None token = None
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language) token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = ( args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args["email"] user_email = args.email
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"]) is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
if is_email_register_error_rate_limit: if is_email_register_error_rate_limit:
raise EmailRegisterLimitError() raise EmailRegisterLimitError()
token_data = AccountService.get_email_register_data(args["token"]) token_data = AccountService.get_email_register_data(args.token)
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): if user_email != token_data.get("email"):
raise InvalidEmailError() raise InvalidEmailError()
if args["code"] != token_data.get("code"): if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args["email"]) AccountService.add_email_register_error_rate_limit(args.email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
AccountService.revoke_email_register_token(args["token"]) AccountService.revoke_email_register_token(args.token)
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_email_register_token( _, new_token = AccountService.generate_email_register_token(
user_email, code=args["code"], additional_data={"phase": "register"} user_email, code=args.code, additional_data={"phase": "register"}
) )
AccountService.reset_email_register_error_rate_limit(args["email"]) AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = ( args = EmailRegisterResetPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match # Validate passwords match
if args["new_password"] != args["password_confirm"]: if args.new_password != args.password_confirm:
raise PasswordMismatchError() raise PasswordMismatchError()
# Validate token and get register data # Validate token and get register data
register_data = AccountService.get_email_register_data(args["token"]) register_data = AccountService.get_email_register_data(args.token)
if not register_data: if not register_data:
raise InvalidTokenError() raise InvalidTokenError()
# Must use token in reset phase # Must use token in reset phase
@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
# Revoke token to prevent reuse # Revoke token to prevent reuse
AccountService.revoke_email_register_token(args["token"]) AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "") email = register_data.get("email", "")
@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
if account: if account:
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
else: else:
account = self._create_new_account(email, args["password_confirm"]) account = self._create_new_account(email, args.password_confirm)
if not account: if not account:
raise AccountNotFoundError() raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))

View File

@ -2,7 +2,8 @@ import base64
import secrets import secrets
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models import Account from models import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/forgot-password") @console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@console_ns.doc("send_forgot_password_email") @console_ns.doc("send_forgot_password_email")
@console_ns.doc(description="Send password reset email") @console_ns.doc(description="Send password reset email")
@console_ns.expect( @console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
console_ns.model(
"ForgotPasswordEmailRequest",
{
"email": fields.String(required=True, description="Email address"),
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Email sent successfully", "Email sent successfully",
@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = ( args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans": if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = AccountService.send_reset_password_email( token = AccountService.send_reset_password_email(
account=account, account=account,
email=args["email"], email=args.email,
language=language, language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register, is_allow_register=FeatureService.get_system_features().is_allow_register,
) )
@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@console_ns.doc("check_forgot_password_code") @console_ns.doc("check_forgot_password_code")
@console_ns.doc(description="Verify password reset code") @console_ns.doc(description="Verify password reset code")
@console_ns.expect( @console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
console_ns.model(
"ForgotPasswordCheckRequest",
{
"email": fields.String(required=True, description="Email address"),
"code": fields.String(required=True, description="Verification code"),
"token": fields.String(required=True, description="Reset token"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Code verified successfully", "Code verified successfully",
@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = ( args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args["email"] user_email = args.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
if is_forgot_password_error_rate_limit: if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError() raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"]) token_data = AccountService.get_reset_password_data(args.token)
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): if user_email != token_data.get("email"):
raise InvalidEmailError() raise InvalidEmailError()
if args["code"] != token_data.get("code"): if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"]) AccountService.add_forgot_password_error_rate_limit(args.email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"]) AccountService.revoke_reset_password_token(args.token)
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token( _, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"} user_email, code=args.code, additional_data={"phase": "reset"}
) )
AccountService.reset_forgot_password_error_rate_limit(args["email"]) AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@console_ns.doc("reset_password") @console_ns.doc("reset_password")
@console_ns.doc(description="Reset password with verification token") @console_ns.doc(description="Reset password with verification token")
@console_ns.expect( @console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
console_ns.model(
"ForgotPasswordResetRequest",
{
"token": fields.String(required=True, description="Verification token"),
"new_password": fields.String(required=True, description="New password"),
"password_confirm": fields.String(required=True, description="Password confirmation"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Password reset successfully", "Password reset successfully",
@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = ( args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match # Validate passwords match
if args["new_password"] != args["password_confirm"]: if args.new_password != args.password_confirm:
raise PasswordMismatchError() raise PasswordMismatchError()
# Validate token and get reset data # Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"]) reset_data = AccountService.get_reset_password_data(args.token)
if not reset_data: if not reset_data:
raise InvalidTokenError() raise InvalidTokenError()
# Must use token in reset phase # Must use token in reset phase
@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
# Revoke token to prevent reuse # Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"]) AccountService.revoke_reset_password_token(args.token)
# Generate secure salt and hash password # Generate secure salt and hash password
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt) password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "") email = reset_data.get("email", "")

View File

@ -1,6 +1,7 @@
import flask_login import flask_login
from flask import make_response, request from flask import make_response, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
import services import services
from configs import dify_config from configs import dify_config
@ -23,7 +24,7 @@ from controllers.console.error import (
) )
from controllers.console.wraps import email_password_login_enabled, setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from libs.token import ( from libs.token import (
clear_access_token_from_cookie, clear_access_token_from_cookie,
@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
class EmailPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class EmailCodeLoginPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
language: str | None = Field(default=None)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(LoginPayload)
reg(EmailPayload)
reg(EmailCodeLoginPayload)
@console_ns.route("/login") @console_ns.route("/login")
class LoginApi(Resource): class LoginApi(Resource):
@ -47,41 +78,36 @@ class LoginApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
parser = ( args = LoginPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("password", type=str, required=True, location="json")
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError() raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
if is_login_error_rate_limit: if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError() raise EmailPasswordLoginLimitError()
invitation = args["invite_token"] # TODO: why invitation is re-assigned with different type?
invitation = args.invite_token # type: ignore
if invitation: if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
try: try:
if invitation: if invitation:
data = invitation.get("data", {}) data = invitation.get("data", {}) # type: ignore
invitee_email = data.get("email") if data else None invitee_email = data.get("email") if data else None
if invitee_email != args["email"]: if invitee_email != args.email:
raise InvalidEmailError() raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"]) account = AccountService.authenticate(args.email, args.password, args.invite_token)
else: else:
account = AccountService.authenticate(args["email"], args["password"]) account = AccountService.authenticate(args.email, args.password)
except services.errors.account.AccountLoginError: except services.errors.account.AccountLoginError:
raise AccountBannedError() raise AccountBannedError()
except services.errors.account.AccountPasswordError: except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"]) AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError() raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
@ -97,7 +123,7 @@ class LoginApi(Resource):
} }
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body # Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"}) response = make_response({"result": "success"})
@ -134,25 +160,21 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource): class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self): def post(self):
parser = ( args = EmailPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans": if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args["email"]) account = AccountService.get_user_through_email(args.email)
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
token = AccountService.send_reset_password_email( token = AccountService.send_reset_password_email(
email=args["email"], email=args.email,
account=account, account=account,
language=language, language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register, is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.route("/email-code-login") @console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginSendEmailApi(Resource):
@setup_required @setup_required
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self): def post(self):
parser = ( args = EmailPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans": if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args["email"]) account = AccountService.get_user_through_email(args.email)
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
if account is None: if account is None:
if FeatureService.get_system_features().is_allow_register: if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language) token = AccountService.send_email_code_login_email(email=args.email, language=language)
else: else:
raise AccountNotFound() raise AccountNotFound()
else: else:
@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.route("/email-code-login/validity") @console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource): class EmailCodeLoginApi(Resource):
@setup_required @setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
def post(self): def post(self):
parser = ( args = EmailCodeLoginPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
user_email = args["email"] user_email = args.email
language = args["language"] language = args.language
token_data = AccountService.get_email_code_login_data(args["token"]) token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if token_data["email"] != args["email"]: if token_data["email"] != args.email:
raise InvalidEmailError() raise InvalidEmailError()
if token_data["code"] != args["code"]: if token_data["code"] != args.code:
raise EmailCodeError() raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"]) AccountService.revoke_email_code_login_token(args.token)
try: try:
account = AccountService.get_user_through_email(user_email) account = AccountService.get_user_through_email(user_email)
except AccountRegisterError: except AccountRegisterError:
@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError: except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded() raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body # Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"}) response = make_response({"result": "success"})

View File

@ -3,7 +3,8 @@ from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar from typing import Concatenate, ParamSpec, TypeVar
from flask import jsonify, request from flask import jsonify, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
@ -20,15 +21,34 @@ R = TypeVar("R")
T = TypeVar("T") T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
class OAuthProviderRequest(BaseModel):
client_id: str
redirect_uri: str
class OAuthTokenRequest(BaseModel):
client_id: str
grant_type: str
code: str | None = None
client_secret: str | None = None
redirect_uri: str | None = None
refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json") json_data = request.get_json()
parsed_args = parser.parse_args() if json_data is None:
client_id = parsed_args.get("client_id")
if not client_id:
raise BadRequest("client_id is required") raise BadRequest("client_id is required")
payload = OAuthClientPayload.model_validate(json_data)
client_id = payload.client_id
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id) oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app: if not oauth_provider_app:
raise NotFound("client_id is invalid") raise NotFound("client_id is invalid")
@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json") payload = OAuthProviderRequest.model_validate(request.get_json())
parsed_args = parser.parse_args() redirect_uri = payload.redirect_uri
redirect_uri = parsed_args.get("redirect_uri")
# check if redirect_uri is valid # check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris: if redirect_uri not in oauth_provider_app.redirect_uris:
@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = ( payload = OAuthTokenRequest.model_validate(request.get_json())
reqparse.RequestParser()
.add_argument("grant_type", type=str, required=True, location="json")
.add_argument("code", type=str, required=False, location="json")
.add_argument("client_secret", type=str, required=False, location="json")
.add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args()
try: try:
grant_type = OAuthGrantType(parsed_args["grant_type"]) grant_type = OAuthGrantType(payload.grant_type)
except ValueError: except ValueError:
raise BadRequest("invalid grant_type") raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE: if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]: if not payload.code:
raise BadRequest("code is required") raise BadRequest("code is required")
if parsed_args["client_secret"] != oauth_provider_app.client_secret: if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid") raise BadRequest("client_secret is invalid")
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris: if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid") raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token( access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id grant_type, code=payload.code, client_id=oauth_provider_app.client_id
) )
return jsonable_encoder( return jsonable_encoder(
{ {
@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
} }
) )
elif grant_type == OAuthGrantType.REFRESH_TOKEN: elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not parsed_args["refresh_token"]: if not payload.refresh_token:
raise BadRequest("refresh_token is required") raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token( access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
) )
return jsonable_encoder( return jsonable_encoder(
{ {

View File

@ -1,6 +1,8 @@
import base64 import base64
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import console_ns
@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: str = Field(..., description="Subscription plan")
interval: str = Field(..., description="Billing interval")
@field_validator("plan")
@classmethod
def validate_plan(cls, value: str) -> str:
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
raise ValueError("Invalid plan")
return value
@field_validator("interval")
@classmethod
def validate_interval(cls, value: str) -> str:
if value not in {"month", "year"}:
raise ValueError("Invalid interval")
return value
class PartnerTenantsPayload(BaseModel):
click_id: str = Field(..., description="Click Id from partner referral link")
for model in (SubscriptionQuery, PartnerTenantsPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/billing/subscription") @console_ns.route("/billing/subscription")
class Subscription(Resource): class Subscription(Resource):
@ -18,20 +49,9 @@ class Subscription(Resource):
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = ( args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id) return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices") @console_ns.route("/billing/invoices")
@ -65,11 +85,10 @@ class PartnerTenants(Resource):
@only_edition_cloud @only_edition_cloud
def put(self, partner_key: str): def put(self, partner_key: str):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
args = parser.parse_args()
try: try:
click_id = args["click_id"] args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8") decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception: except Exception:
raise BadRequest("Invalid partner_key") raise BadRequest("Invalid partner_key")

View File

@ -1,5 +1,6 @@
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -9,16 +10,28 @@ from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required from ..wraps import account_initialization_required, only_edition_cloud, setup_required
class ComplianceDownloadQuery(BaseModel):
doc_name: str = Field(..., description="Compliance document name")
console_ns.schema_model(
ComplianceDownloadQuery.__name__,
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/compliance/download") @console_ns.route("/compliance/download")
class ComplianceApi(Resource): class ComplianceApi(Resource):
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
@console_ns.doc("download_compliance_document")
@console_ns.doc(description="Get compliance document download link")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args") args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device") device_info = request.headers.get("User-Agent", "Unknown device")

View File

@ -1,15 +1,15 @@
import json import json
from collections.abc import Generator from collections.abc import Generator
from typing import cast from typing import Any, cast
from flask import request from flask import request
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import console_ns from controllers.common.schema import register_schema_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task from tasks.document_indexing_sync_task import document_indexing_sync_task
from .. import console_ns
from ..wraps import account_initialization_required, setup_required
class NotionEstimatePayload(BaseModel):
notion_info_list: list[dict[str, Any]]
process_rule: dict[str, Any]
doc_form: str = Field(default="text_model")
doc_language: str = Field(default="English")
register_schema_model(console_ns, NotionEstimatePayload)
@console_ns.route( @console_ns.route(
"/data-source/integrates", "/data-source/integrates",
@ -243,20 +256,15 @@ class DataSourceNotionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
def post(self): def post(self):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
parser = ( payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump()
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
notion_info_list = args["notion_info_list"] notion_info_list = payload.notion_info_list
extract_settings = [] extract_settings = []
for notion_info in notion_info_list: for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"] workspace_id = notion_info["workspace_id"]

View File

@ -1,12 +1,14 @@
from typing import Any, cast from typing import Any, cast
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from configs import dify_config from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.apikey import ( from controllers.console.apikey import (
api_key_item_model, api_key_item_model,
@ -48,7 +50,6 @@ from fields.dataset_fields import (
) )
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
@ -107,10 +108,75 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy) related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_name(name: str) -> str: def _validate_indexing_technique(value: str | None) -> str | None:
if not name or len(name) < 1 or len(name) > 40: if value is None:
raise ValueError("Name must be between 1 to 40 characters.") return value
return name if value not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Invalid indexing technique.")
return value
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field("", max_length=400)
indexing_technique: str | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
provider: str = "vendor"
external_knowledge_api_id: str | None = None
external_knowledge_id: str | None = None
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str | None) -> str | None:
return _validate_indexing_technique(value)
@field_validator("provider")
@classmethod
def validate_provider(cls, value: str) -> str:
if value not in Dataset.PROVIDER_LIST:
raise ValueError("Invalid provider.")
return value
class DatasetUpdatePayload(BaseModel):
name: str | None = Field(None, min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
permission: DatasetPermissionEnum | None = None
indexing_technique: str | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
icon_info: dict[str, Any] | None = None
is_multimodal: bool | None = False
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str | None) -> str | None:
return _validate_indexing_technique(value)
class IndexingEstimatePayload(BaseModel):
info_list: dict[str, Any]
process_rule: dict[str, Any]
indexing_technique: str
doc_form: str = "text_model"
dataset_id: str | None = None
doc_language: str = "English"
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str) -> str:
result = _validate_indexing_technique(value)
if result is None:
raise ValueError("indexing_technique is required.")
return result
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]: def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -255,20 +321,7 @@ class DatasetListApi(Resource):
@console_ns.doc("create_dataset") @console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset") @console_ns.doc(description="Create a new dataset")
@console_ns.expect( @console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
console_ns.model(
"CreateDatasetRequest",
{
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
"description": fields.String(description="Dataset description (max 400 characters)"),
"indexing_technique": fields.String(description="Indexing technique"),
"permission": fields.String(description="Dataset permission"),
"provider": fields.String(description="Provider"),
"external_knowledge_api_id": fields.String(description="External knowledge API ID"),
"external_knowledge_id": fields.String(description="External knowledge ID"),
},
)
)
@console_ns.response(201, "Dataset created successfully") @console_ns.response(201, "Dataset created successfully")
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@ -276,52 +329,7 @@ class DatasetListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = ( payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
)
.add_argument(
"provider",
type=str,
nullable=True,
choices=Dataset.PROVIDER_LIST,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -331,14 +339,14 @@ class DatasetListApi(Resource):
try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
name=args["name"], name=payload.name,
description=args["description"], description=payload.description,
indexing_technique=args["indexing_technique"], indexing_technique=payload.indexing_technique,
account=current_user, account=current_user,
permission=DatasetPermissionEnum.ONLY_ME, permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
provider=args["provider"], provider=payload.provider,
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=args["external_knowledge_id"], external_knowledge_id=payload.external_knowledge_id,
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -399,18 +407,7 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset") @console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details") @console_ns.doc(description="Update dataset details")
@console_ns.expect( @console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
console_ns.model(
"UpdateDatasetRequest",
{
"name": fields.String(description="Dataset name"),
"description": fields.String(description="Dataset description"),
"permission": fields.String(description="Dataset permission"),
"indexing_technique": fields.String(description="Indexing technique"),
"external_retrieval_model": fields.Raw(description="External retrieval model settings"),
},
)
)
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model) @console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@ -424,93 +421,25 @@ class DatasetApi(Resource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
parser = ( payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(
DatasetPermissionEnum.ONLY_ME,
DatasetPermissionEnum.ALL_TEAM,
DatasetPermissionEnum.PARTIAL_TEAM,
),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
)
args = parser.parse_args()
data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting # check embedding model setting
if ( if (
data.get("indexing_technique") == "high_quality" payload.indexing_technique == "high_quality"
and data.get("embedding_model_provider") is not None and payload.embedding_model_provider is not None
and data.get("embedding_model") is not None and payload.embedding_model is not None
): ):
DatasetService.check_embedding_model_setting( is_multimodal = DatasetService.check_is_multimodal_model(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
) )
payload.is_multimodal = is_multimodal
payload_data = payload.model_dump(exclude_unset=True)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission( DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list") current_user, dataset, payload.permission, payload.partial_member_list
) )
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -518,15 +447,10 @@ class DatasetApi(Resource):
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_tenant_id tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members # clear partial member list when permission is only_me or all_team_members
elif ( elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str) DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -615,24 +539,10 @@ class DatasetIndexingEstimateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
def post(self): def post(self):
parser = ( payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump()
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument(
"indexing_technique",
type=str,
required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
location="json",
)
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)

View File

@ -6,31 +6,14 @@ from typing import Literal, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import asc, desc, select from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
from core.errors.error import ( from core.errors.error import (
LLMBadRequestError, LLMBadRequestError,
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
@ -55,10 +38,30 @@ from fields.document_fields import (
) )
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from ..datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
document_ids: list[str]
class DocumentRenamePayload(BaseModel):
name: str
register_schema_models(
console_ns,
KnowledgeConfig,
ProcessRule,
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
)
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@ -201,8 +222,9 @@ class DatasetDocumentListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id: str): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
@ -310,6 +332,7 @@ class DatasetDocumentListApi(Resource):
@marshal_with(dataset_and_document_model) @marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -328,23 +351,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = ( knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("data_source", type=dict, required=False, location="json")
.add_argument("process_rule", type=dict, required=False, location="json")
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique: if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -390,17 +397,7 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
@console_ns.doc("init_dataset") @console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents") @console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect( @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
console_ns.model(
"DatasetInitRequest",
{
"upload_file_id": fields.String(required=True, description="Upload file ID"),
"indexing_technique": fields.String(description="Indexing technique"),
"process_rule": fields.Raw(description="Processing rules"),
"data_source": fields.Raw(description="Data source configuration"),
},
)
)
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@ -415,27 +412,7 @@ class DatasetInitApi(Resource):
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
parser = ( knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"indexing_technique",
type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
required=True,
nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.") raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -443,10 +420,14 @@ class DatasetInitApi(Resource):
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider=args["embedding_model_provider"], provider=knowledge_config.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"], model=knowledge_config.embedding_model,
) )
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError: except InvokeAuthorizationError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@ -1076,19 +1057,16 @@ class DocumentRetryApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
def post(self, dataset_id): def post(self, dataset_id):
"""retry document.""" """retry document."""
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
parser = reqparse.RequestParser().add_argument(
"document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
retry_documents = [] retry_documents = []
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
for document_id in args["document_ids"]: for document_id in payload.document_ids:
try: try:
document_id = str(document_id) document_id = str(document_id)
@ -1121,6 +1099,7 @@ class DocumentRenameApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(document_fields) @marshal_with(document_fields)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
@ -1130,11 +1109,10 @@ class DocumentRenameApi(DocumentResource):
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(current_user, dataset) DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json") payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
args = parser.parse_args()
try: try:
document = DocumentService.rename_document(dataset_id, document_id, args["name"]) document = DocumentService.rename_document(dataset_id, document_id, payload.name)
except services.errors.document.DocumentIndexingError: except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.") raise DocumentIndexingError("Cannot delete document during indexing.")

View File

@ -1,11 +1,13 @@
import uuid import uuid
from flask import request from flask import request
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import ( from controllers.console.datasets.error import (
@ -36,6 +38,58 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
hit_count_gte: int | None = None
enabled: str = Field(default="all")
keyword: str | None = None
page: int = Field(default=1, ge=1)
class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
class BatchImportPayload(BaseModel):
upload_file_id: str
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]
register_schema_models(
console_ns,
SegmentListQuery,
SegmentCreatePayload,
SegmentUpdatePayload,
BatchImportPayload,
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource): class DatasetDocumentSegmentListApi(Resource):
@setup_required @setup_required
@ -60,23 +114,18 @@ class DatasetDocumentSegmentListApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = ( args = SegmentListQuery.model_validate(
reqparse.RequestParser() {
.add_argument("limit", type=int, default=20, location="args") **request.args.to_dict(),
.add_argument("status", type=str, action="append", default=[], location="args") "status": request.args.getlist("status"),
.add_argument("hit_count_gte", type=int, default=None, location="args") }
.add_argument("enabled", type=str, default="all", location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
) )
args = parser.parse_args() page = args.page
limit = min(args.limit, 100)
page = args["page"] status_list = args.status
limit = min(args["limit"], 100) hit_count_gte = args.hit_count_gte
status_list = args["status"] keyword = args.keyword
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
query = ( query = (
select(DocumentSegment) select(DocumentSegment)
@ -96,10 +145,10 @@ class DatasetDocumentSegmentListApi(Resource):
if keyword: if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all": if args.enabled.lower() != "all":
if args["enabled"].lower() == "true": if args.enabled.lower() == "true":
query = query.where(DocumentSegment.enabled == True) query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false": elif args.enabled.lower() == "false":
query = query.where(DocumentSegment.enabled == False) query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -210,6 +259,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@ -246,15 +296,10 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = ( payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() payload_dict = payload.model_dump(exclude_none=True)
.add_argument("content", type=str, required=True, nullable=False, location="json") SegmentService.segment_create_args_validate(payload_dict, document)
.add_argument("answer", type=str, required=False, nullable=True, location="json") segment = SegmentService.create_segment(payload_dict, document, dataset)
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@ -265,6 +310,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@ -313,18 +359,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = ( payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() payload_dict = payload.model_dump(exclude_none=True)
.add_argument("content", type=str, required=True, nullable=False, location="json") SegmentService.segment_create_args_validate(payload_dict, document)
.add_argument("answer", type=str, required=False, nullable=True, location="json") segment = SegmentService.update_segment(
.add_argument("keywords", type=list, required=False, nullable=True, location="json") SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
) )
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required @setup_required
@ -377,6 +417,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@ -391,11 +432,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser().add_argument( payload = BatchImportPayload.model_validate(console_ns.payload or {})
"upload_file_id", type=str, required=True, nullable=False, location="json" upload_file_id = payload.upload_file_id
)
args = parser.parse_args()
upload_file_id = args["upload_file_id"]
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file: if not upload_file:
@ -446,6 +484,7 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id, document_id, segment_id): def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@ -491,13 +530,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try: try:
content = args["content"] payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset) child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@ -529,18 +564,17 @@ class ChildChunkAddApi(Resource):
) )
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
parser = ( args = SegmentListQuery.model_validate(
reqparse.RequestParser() {
.add_argument("limit", type=int, default=20, location="args") "limit": request.args.get("limit", default=20, type=int),
.add_argument("keyword", type=str, default=None, location="args") "keyword": request.args.get("keyword"),
.add_argument("page", type=int, default=1, location="args") "page": request.args.get("page", default=1, type=int),
}
) )
args = parser.parse_args() page = args.page
limit = min(args.limit, 100)
page = args["page"] keyword = args.keyword
limit = min(args["limit"], 100)
keyword = args["keyword"]
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return { return {
@ -588,14 +622,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser().add_argument( payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
"chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try: try:
chunks_data = args["chunks"] child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200 return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@ -665,6 +694,7 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id, child_chunk_id): def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@ -711,13 +741,9 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try: try:
content = args["content"] payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset) child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@ -1,8 +1,10 @@
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, reqparse from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -71,10 +73,38 @@ except KeyError:
dataset_detail_model = _build_dataset_detail_model() dataset_detail_model = _build_dataset_detail_model()
def _validate_name(name: str) -> str: class ExternalKnowledgeApiPayload(BaseModel):
if not name or len(name) < 1 or len(name) > 100: name: str = Field(..., min_length=1, max_length=40)
raise ValueError("Name must be between 1 to 100 characters.") settings: dict[str, object]
return name
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
name: str = Field(..., min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
class ExternalHitTestingPayload(BaseModel):
query: str
external_retrieval_model: dict[str, object] | None = None
metadata_filtering_conditions: dict[str, object] | None = None
class BedrockRetrievalPayload(BaseModel):
retrieval_setting: dict[str, object]
query: str
knowledge_id: str
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
)
@console_ns.route("/datasets/external-knowledge-api") @console_ns.route("/datasets/external-knowledge-api")
@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def post(self): def post(self):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = ( payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"]) ExternalDatasetService.validate_api_list(payload.settings)
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
try: try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api( external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_tenant_id, user_id=current_user.id, args=args tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id): def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
parser = ( payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() ExternalDatasetService.validate_api_list(payload.settings)
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
user_id=current_user.id, user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id, external_knowledge_api_id=external_knowledge_api_id,
args=args, args=payload.model_dump(),
) )
return external_knowledge_api.to_dict(), 200 return external_knowledge_api.to_dict(), 200
@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource):
class ExternalDatasetCreateApi(Resource): class ExternalDatasetCreateApi(Resource):
@console_ns.doc("create_external_dataset") @console_ns.doc("create_external_dataset")
@console_ns.doc(description="Create external knowledge dataset") @console_ns.doc(description="Create external knowledge dataset")
@console_ns.expect( @console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
console_ns.model(
"CreateExternalDatasetRequest",
{
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
"name": fields.String(required=True, description="Dataset name"),
"description": fields.String(description="Dataset description"),
},
)
)
@console_ns.response(201, "External dataset created successfully", dataset_detail_model) @console_ns.response(201, "External dataset created successfully", dataset_detail_model)
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = ( payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.doc("test_external_knowledge_retrieval") @console_ns.doc("test_external_knowledge_retrieval")
@console_ns.doc(description="Test external knowledge retrieval for dataset") @console_ns.doc(description="Test external knowledge retrieval for dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
console_ns.model(
"ExternalHitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
},
)
)
@console_ns.response(200, "External hit testing completed successfully") @console_ns.response(200, "External hit testing completed successfully")
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = ( payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() HitTestingService.hit_testing_args_check(payload.model_dump())
.add_argument("query", type=str, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try: try:
response = HitTestingService.external_retrieve( response = HitTestingService.external_retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=payload.query,
account=current_user, account=current_user,
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=payload.external_retrieval_model,
metadata_filtering_conditions=args["metadata_filtering_conditions"], metadata_filtering_conditions=payload.metadata_filtering_conditions,
) )
return response return response
@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource):
# this api is only for internal testing # this api is only for internal testing
@console_ns.doc("bedrock_retrieval_test") @console_ns.doc("bedrock_retrieval_test")
@console_ns.doc(description="Bedrock retrieval test (internal use only)") @console_ns.doc(description="Bedrock retrieval test (internal use only)")
@console_ns.expect( @console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
console_ns.model(
"BedrockRetrievalTestRequest",
{
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
"query": fields.String(required=True, description="Query text"),
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
},
)
)
@console_ns.response(200, "Bedrock retrieval test completed") @console_ns.response(200, "Bedrock retrieval test completed")
def post(self): def post(self):
parser = ( payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
)
args = parser.parse_args()
# Call the knowledge retrieval service # Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval( result = ExternalDatasetTestService.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"] payload.retrieval_setting, payload.query, payload.knowledge_id
) )
return result, 200 return result, 200

View File

@ -1,13 +1,17 @@
from flask_restx import Resource, fields from flask_restx import Resource
from controllers.console import console_ns from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from libs.login import login_required
from controllers.console.wraps import (
from .. import console_ns
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from ..wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_rate_limit_check, cloud_edition_billing_rate_limit_check,
setup_required, setup_required,
) )
from libs.login import login_required
register_schema_model(console_ns, HitTestingPayload)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing") @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval") @console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
console_ns.model(
"HitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"top_k": fields.Integer(description="Number of top results to return"),
"score_threshold": fields.Float(description="Score threshold for filtering results"),
},
)
)
@console_ns.response(200, "Hit testing completed successfully") @console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str) dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args() payload = HitTestingPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args) self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args) return self.perform_hit_testing(dataset, args)

View File

@ -1,6 +1,8 @@
import logging import logging
from typing import Any
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
@ -27,6 +29,13 @@ from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
class DatasetsHitTestingBase: class DatasetsHitTestingBase:
@staticmethod @staticmethod
def get_and_validate_dataset(dataset_id: str): def get_and_validate_dataset(dataset_id: str):
@ -43,14 +52,15 @@ class DatasetsHitTestingBase:
return dataset return dataset
@staticmethod @staticmethod
def hit_testing_args_check(args): def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args) HitTestingService.hit_testing_args_check(args)
@staticmethod @staticmethod
def parse_args(): def parse_args():
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
.add_argument("query", type=str, location="json") .add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json") .add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json")
) )
@ -62,10 +72,11 @@ class DatasetsHitTestingBase:
try: try:
response = HitTestingService.retrieve( response = HitTestingService.retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args.get("query"),
account=current_user, account=current_user,
retrieval_model=args["retrieval_model"], retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
limit=10, limit=10,
) )
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}

View File

@ -1,8 +1,10 @@
from typing import Literal from typing import Literal
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata") @console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource): class DatasetMetadataCreateApi(Resource):
@setup_required @setup_required
@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id): def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json") payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args() name = payload.name
name = args["name"]
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser().add_argument( metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
"operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -1,20 +1,63 @@
from typing import Any
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
class DatasourceCredentialPayload(BaseModel):
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any]
class DatasourceCredentialDeletePayload(BaseModel):
credential_id: str
class DatasourceCredentialUpdatePayload(BaseModel):
credential_id: str
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any] | None = None
class DatasourceCustomClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enable_oauth_custom_client: bool | None = None
class DatasourceDefaultPayload(BaseModel):
id: str
class DatasourceUpdateNamePayload(BaseModel):
credential_id: str
name: str = Field(max_length=100)
register_schema_models(
console_ns,
DatasourceCredentialPayload,
DatasourceCredentialDeletePayload,
DatasourceCredentialUpdatePayload,
DatasourceCustomClientPayload,
DatasourceDefaultPayload,
DatasourceUpdateNamePayload,
)
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url") @console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource): class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>") @console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource): class DatasourceAuth(Resource):
@console_ns.expect(parser_datasource) @console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_datasource.parse_args() payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
datasource_provider_service.add_datasource_api_key_provider( datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider_id=datasource_provider_id, provider_id=datasource_provider_id,
credentials=args["credentials"], credentials=payload.credentials,
name=args["name"], name=payload.name,
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200 return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource): class DatasourceAuthDeleteApi(Resource):
@console_ns.expect(parser_datasource_delete) @console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name provider_name = datasource_provider_id.provider_name
args = parser_datasource_delete.parse_args() payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=payload.credential_id,
provider=provider_name, provider=provider_name,
plugin_id=plugin_id, plugin_id=plugin_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource): class DatasourceAuthUpdateApi(Resource):
@console_ns.expect(parser_datasource_update) @console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
args = parser_datasource_update.parse_args() payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=payload.credential_id,
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
credentials=args.get("credentials", {}), credentials=payload.credentials or {},
name=args.get("name", None), name=payload.name,
) )
return {"result": "success"}, 201 return {"result": "success"}, 201
@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource): class DatasourceAuthOauthCustomClient(Resource):
@console_ns.expect(parser_datasource_custom) @console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_datasource_custom.parse_args() payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params( datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}), client_params=payload.client_params or {},
enabled=args.get("enable_oauth_custom_client", False), enabled=payload.enable_oauth_custom_client or False,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@console_ns.expect(parser_default) @console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_default.parse_args() payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider( datasource_provider_service.set_default_datasource_provider(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
credential_id=args["id"], credential_id=payload.id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource): class DatasourceUpdateProviderNameApi(Resource):
@console_ns.expect(parser_update_name) @console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_update_name.parse_args() payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name( datasource_provider_service.update_datasource_provider_name(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
name=args["name"], name=payload.name,
credential_id=args["credential_id"], credential_id=payload.credential_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource): class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__], validate=True) @console_ns.expect(console_ns.models[Parser.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,9 +1,11 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description: str) -> str:
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/rag/pipeline/templates") @console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource): class PipelineTemplateListApi(Resource):
@setup_required @setup_required
@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
return pipeline_template, 200 return pipeline_template, 200
class Payload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] | None = None
register_schema_models(console_ns, Payload)
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>") @console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource): class CustomizedPipelineTemplateApi(Resource):
@setup_required @setup_required
@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def patch(self, template_id: str): def patch(self, template_id: str):
parser = ( payload = Payload.model_validate(console_ns.payload or {})
reqparse.RequestParser() pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200 return 200
@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish") @console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource): class PublishCustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[Payload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@knowledge_pipeline_publish_enabled @knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str): def post(self, pipeline_id: str):
parser = ( payload = Payload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
return {"result": "success"} return {"result": "success"}

View File

@ -1,8 +1,10 @@
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.common.schema import register_schema_model
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineDatasetImportPayload(BaseModel):
yaml_content: str
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
@console_ns.route("/rag/pipeline/dataset") @console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource): class CreateRagPipelineDatasetApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser().add_argument( payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
), ),
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
partial_member_list=None, partial_member_list=None,
yaml_content=args["yaml_content"], yaml_content=payload.yaml_content,
) )
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:

View File

@ -1,11 +1,13 @@
import logging import logging
from typing import NoReturn from typing import Any, NoReturn
from flask import Response from flask import Response, request
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
DraftWorkflowNotExist, DraftWorkflowNotExist,
@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
def _create_pagination_parser(): def _create_pagination_parser():
parser = ( class PaginationQuery(BaseModel):
reqparse.RequestParser() page: int = Field(default=1, ge=1, le=100_000)
.add_argument( limit: int = Field(default=20, ge=1, le=100)
"page",
type=inputs.int_range(1, 100_000), register_schema_models(console_ns, PaginationQuery)
required=False,
default=1, return PaginationQuery
location="args",
help="the page of data requested",
) class WorkflowDraftVariablePatchPayload(BaseModel):
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") name: str | None = None
) value: Any | None = None
return parser
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
""" """
Get draft workflow Get draft workflow
""" """
parser = _create_pagination_parser() pagination = _create_pagination_parser()
args = parser.parse_args() query = pagination.model_validate(request.args.to_dict())
# fetch draft workflow by app_model # fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
) )
workflow_vars = draft_var_srv.list_variables_without_values( workflow_vars = draft_var_srv.list_variables_without_values(
app_id=pipeline.id, app_id=pipeline.id,
page=args.page, page=query.page,
limit=args.limit, limit=query.limit,
) )
return workflow_vars return workflow_vars
@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str): def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types: # Request payload for file types:
# #
@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),
) )
args = parser.parse_args(strict=True) payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable = draft_var_srv.get_variable(variable_id=variable_id) variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None: if variable is None:

View File

@ -1,6 +1,9 @@
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask import request
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportPayload(BaseModel):
mode: str
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
pipeline_id: str | None = None
class IncludeSecretQuery(BaseModel):
include_secret: str = Field(default="false")
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
@console_ns.route("/rag/pipelines/imports") @console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource): class RagPipelineImportApi(Resource):
@setup_required @setup_required
@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
@account_initialization_required @account_initialization_required
@edit_permission_required @edit_permission_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self): def post(self):
# Check user role first # Check user role first
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("pipeline_id", type=str, location="json")
)
args = parser.parse_args()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
account = current_user account = current_user
result = import_service.import_rag_pipeline( result = import_service.import_rag_pipeline(
account=account, account=account,
import_mode=args["mode"], import_mode=payload.mode,
yaml_content=args.get("yaml_content"), yaml_content=payload.yaml_content,
yaml_url=args.get("yaml_url"), yaml_url=payload.yaml_url,
pipeline_id=args.get("pipeline_id"), pipeline_id=payload.pipeline_id,
dataset_name=args.get("name"), dataset_name=payload.name,
) )
session.commit() session.commit()
@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
@edit_permission_required @edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args") query = IncludeSecretQuery.model_validate(request.args.to_dict())
args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
export_service = RagPipelineDslService(session) export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl( result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=args["include_secret"] == "true" pipeline=pipeline, include_secret=query.include_secret == "true"
) )
return {"data": result}, 200 return {"data": result}, 200

View File

@ -1,14 +1,16 @@
import json import json
import logging import logging
from typing import cast from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request from flask import abort, request
from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore from flask_restx import Resource, marshal_with # type: ignore
from flask_restx.inputs import int_range # type: ignore from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
ConversationCompletedError, ConversationCompletedError,
@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
workflow_run_pagination_fields, workflow_run_pagination_fields,
) )
from libs import helper from libs import helper
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required from libs.login import current_account_with_tenant, current_user, login_required
from models import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DraftWorkflowSyncPayload(BaseModel):
graph: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] | None = None
conversation_variables: list[dict[str, Any]] | None = None
rag_pipeline_variables: list[dict[str, Any]] | None = None
features: dict[str, Any] | None = None
class NodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class NodeRunRequiredPayload(BaseModel):
inputs: dict[str, Any]
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
class DraftWorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
datasource_info_list: list[dict[str, Any]]
start_node_id: str
class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
is_preview: bool = False
response_mode: Literal["streaming", "blocking"] = "streaming"
original_document_id: str | None = None
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class NodeIdQuery(BaseModel):
node_id: str
class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class DatasourceVariablesPayload(BaseModel):
datasource_type: str
datasource_info: dict[str, Any]
start_node_id: str
start_node_title: str
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
NodeRunPayload,
NodeRunRequiredPayload,
DatasourceNodeRunPayload,
DraftWorkflowRunPayload,
PublishedWorkflowRunPayload,
DefaultBlockConfigQuery,
WorkflowListQuery,
WorkflowUpdatePayload,
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource): class DraftRagPipelineApi(Resource):
@setup_required @setup_required
@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = ( payload_dict = console_ns.payload or {}
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=False, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
data = json.loads(request.data.decode("utf-8")) data = json.loads(request.data.decode("utf-8"))
@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
if not isinstance(data.get("graph"), dict): if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict") raise ValueError("graph is not a dict")
args = { payload_dict = {
"graph": data.get("graph"), "graph": data.get("graph"),
"features": data.get("features"), "features": data.get("features"),
"hash": data.get("hash"), "hash": data.get("hash"),
@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
else: else:
abort(415) abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
try: try:
environment_variables_list = args.get("environment_variables") or [] environment_variables_list = payload.environment_variables or []
environment_variables = [ environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
] ]
conversation_variables_list = args.get("conversation_variables") or [] conversation_variables_list = payload.conversation_variables or []
conversation_variables = [ conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
] ]
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow( workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline, pipeline=pipeline,
graph=args["graph"], graph=payload.graph,
unique_hash=args.get("hash"), unique_hash=payload.hash,
account=current_user, account=current_user,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
rag_pipeline_variables=args.get("rag_pipeline_variables") or [], rag_pipeline_variables=payload.rag_pipeline_variables or [],
) )
except WorkflowHashNotEqualError: except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync() raise DraftWorkflowNotSync()
@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
} }
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource): class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.expect(parser_run) @console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_run.parse_args() payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try: try:
response = PipelineGenerateService.generate_single_iteration( response = PipelineGenerateService.generate_single_iteration(
@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource): class RagPipelineDraftRunLoopNodeApi(Resource):
@console_ns.expect(parser_run) @console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_run.parse_args() payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try: try:
response = PipelineGenerateService.generate_single_loop( response = PipelineGenerateService.generate_single_loop(
@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError() raise InternalServerError()
parser_draft_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource): class DraftRagPipelineRunApi(Resource):
@console_ns.expect(parser_draft_run) @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_draft_run.parse_args() payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
try: try:
response = PipelineGenerateService.generate( response = PipelineGenerateService.generate(
@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description) raise InvokeRateLimitHttpError(ex.description)
parser_published_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource): class PublishedRagPipelineRunApi(Resource):
@console_ns.expect(parser_published_run) @console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_published_run.parse_args() payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = args["response_mode"] == "streaming" streaming = payload.response_mode == "streaming"
try: try:
response = PipelineGenerateService.generate( response = PipelineGenerateService.generate(
pipeline=pipeline, pipeline=pipeline,
user=current_user, user=current_user,
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED, invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
streaming=streaming, streaming=streaming,
) )
@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
# #
# return result # return result
# #
parser_rag_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource): class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run) @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_rag_run.parse_args() payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response( return helper.compact_generate_response(
@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node( rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, pipeline=pipeline,
node_id=node_id, node_id=node_id,
user_inputs=inputs, user_inputs=payload.inputs,
account=current_user, account=current_user,
datasource_type=datasource_type, datasource_type=payload.datasource_type,
is_published=False, is_published=False,
credential_id=args.get("credential_id"), credential_id=payload.credential_id,
) )
) )
) )
@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource): class RagPipelineDraftDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run) @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@edit_permission_required @edit_permission_required
@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_rag_run.parse_args() payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response( return helper.compact_generate_response(
@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node( rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, pipeline=pipeline,
node_id=node_id, node_id=node_id,
user_inputs=inputs, user_inputs=payload.inputs,
account=current_user, account=current_user,
datasource_type=datasource_type, datasource_type=payload.datasource_type,
is_published=False, is_published=False,
credential_id=args.get("credential_id"), credential_id=payload.credential_id,
) )
) )
) )
parser_run_api = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource): class RagPipelineDraftNodeRunApi(Resource):
@console_ns.expect(parser_run_api) @console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@edit_permission_required @edit_permission_required
@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_run_api.parse_args() payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
inputs = payload.inputs
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node( workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs() return rag_pipeline_service.get_default_block_configs()
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource): class DefaultRagPipelineBlockConfigApi(Resource):
@console_ns.expect(parser_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
""" """
Get default block config Get default block config
""" """
args = parser_default.parse_args() query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
q = args.get("q")
filters = None filters = None
if q: if query.q:
try: try:
filters = json.loads(args.get("q", "")) filters = json.loads(query.q)
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError("Invalid filters") raise ValueError("Invalid filters")
@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
parser_wf = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource): class PublishedAllRagPipelineApi(Resource):
@console_ns.expect(parser_wf)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_wf.parse_args() query = WorkflowListQuery.model_validate(request.args.to_dict())
page = args["page"]
limit = args["limit"] page = query.page
user_id = args.get("user_id") limit = query.limit
named_only = args.get("named_only", False) user_id = query.user_id
named_only = query.named_only
if user_id: if user_id:
if user_id != current_user.id: if user_id != current_user.id:
raise Forbidden() raise Forbidden()
user_id = cast(str, user_id)
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session: with Session(db.engine) as session:
@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
} }
parser_wf_id = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource): class RagPipelineByIdApi(Resource):
@console_ns.expect(parser_wf_id)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
# Check permission # Check permission
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_wf_id.parse_args() payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
if not update_data: if not update_data:
return {"message": "No valid fields to update"}, 400 return {"message": "No valid fields to update"}, 400
@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
return workflow return workflow
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource): class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
args = parser_parameters.parse_args() query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = args.get("node_id") node_id = query.node_id
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return { return {
@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource): class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
args = parser_parameters.parse_args() query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = args.get("node_id") node_id = query.node_id
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return { return {
@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource): class DraftRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
args = parser_parameters.parse_args() query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = args.get("node_id") node_id = query.node_id
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
return { return {
@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource): class DraftRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
args = parser_parameters.parse_args() query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = args.get("node_id") node_id = query.node_id
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
} }
parser_wf_run = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource): class RagPipelineWorkflowRunListApi(Resource):
@console_ns.expect(parser_wf_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
args = parser_wf_run.parse_args() query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": str(query.last_id) if query.last_id else None,
"limit": query.limit,
}
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
return result return result
parser_var = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource): class RagPipelineDatasourceVariableApi(Resource):
@console_ns.expect(parser_var) @console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
Set datasource variables Set datasource variables
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser_var.parse_args() args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables( workflow_node_execution = rag_pipeline_service.set_datasource_variables(

View File

@ -1,5 +1,10 @@
from flask_restx import Resource, fields, reqparse from typing import Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
@ -7,48 +12,35 @@ from libs.login import login_required
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
class WebsiteCrawlPayload(BaseModel):
provider: Literal["firecrawl", "watercrawl", "jinareader"]
url: str
options: dict[str, object]
class WebsiteCrawlStatusQuery(BaseModel):
provider: Literal["firecrawl", "watercrawl", "jinareader"]
register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery)
@console_ns.route("/website/crawl") @console_ns.route("/website/crawl")
class WebsiteCrawlApi(Resource): class WebsiteCrawlApi(Resource):
@console_ns.doc("crawl_website") @console_ns.doc("crawl_website")
@console_ns.doc(description="Crawl website content") @console_ns.doc(description="Crawl website content")
@console_ns.expect( @console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__])
console_ns.model(
"WebsiteCrawlRequest",
{
"provider": fields.String(
required=True,
description="Crawl provider (firecrawl/watercrawl/jinareader)",
enum=["firecrawl", "watercrawl", "jinareader"],
),
"url": fields.String(required=True, description="URL to crawl"),
"options": fields.Raw(required=True, description="Crawl options"),
},
)
)
@console_ns.response(200, "Website crawl initiated successfully") @console_ns.response(200, "Website crawl initiated successfully")
@console_ns.response(400, "Invalid crawl parameters") @console_ns.response(400, "Invalid crawl parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = ( payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
)
.add_argument("url", type=str, required=True, nullable=True, location="json")
.add_argument("options", type=dict, required=True, nullable=True, location="json")
)
args = parser.parse_args()
# Create typed request and validate # Create typed request and validate
try: try:
api_request = WebsiteCrawlApiRequest.from_args(args) api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump())
except ValueError as e: except ValueError as e:
raise WebsiteCrawlError(str(e)) raise WebsiteCrawlError(str(e))
@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource):
@console_ns.doc("get_crawl_status") @console_ns.doc("get_crawl_status")
@console_ns.doc(description="Get website crawl status") @console_ns.doc(description="Get website crawl status")
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"}) @console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
@console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__])
@console_ns.response(200, "Crawl status retrieved successfully") @console_ns.response(200, "Crawl status retrieved successfully")
@console_ns.response(404, "Crawl job not found") @console_ns.response(404, "Crawl job not found")
@console_ns.response(400, "Invalid provider") @console_ns.response(400, "Invalid provider")
@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, job_id: str): def get(self, job_id: str):
parser = reqparse.RequestParser().add_argument( args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict())
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
# Create typed request and validate # Create typed request and validate
try: try:
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id) api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id)
except ValueError as e: except ValueError as e:
raise WebsiteCrawlError(str(e)) raise WebsiteCrawlError(str(e))

View File

@ -1,9 +1,11 @@
import logging import logging
from flask import request from flask import request
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.common.schema import register_schema_model
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
AudioTooLargeError, AudioTooLargeError,
@ -31,6 +33,16 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload)
@console_ns.route( @console_ns.route(
"/installed-apps/<uuid:installed_app_id>/audio-to-text", "/installed-apps/<uuid:installed_app_id>/audio-to-text",
endpoint="installed_app_audio", endpoint="installed_app_audio",
@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource):
endpoint="installed_app_text", endpoint="installed_app_text",
) )
class ChatTextApi(InstalledAppResource): class ChatTextApi(InstalledAppResource):
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
def post(self, installed_app): def post(self, installed_app):
from flask_restx import reqparse
app_model = installed_app.app app_model = installed_app.app
try: try:
parser = ( payload = TextToAudioPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None) message_id = payload.message_id
text = args.get("text", None) text = payload.text
voice = args.get("voice", None) voice = payload.voice
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response return response

View File

@ -1,9 +1,12 @@
import logging import logging
from typing import Any, Literal
from uuid import UUID
from flask_restx import reqparse from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,
@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from libs.login import current_user from libs.login import current_user
from models import Account from models import Account
from models.model import AppMode from models.model import AppMode
@ -38,28 +40,56 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="explore_app")
class ChatMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = Field(default="explore_app")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
"""
Accept blank IDs and validate UUID format when provided.
"""
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user # define completion api for user
@console_ns.route( @console_ns.route(
"/installed-apps/<uuid:installed_app_id>/completion-messages", "/installed-apps/<uuid:installed_app_id>/completion-messages",
endpoint="installed_app_completion", endpoint="installed_app_completion",
) )
class CompletionApi(InstalledAppResource): class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION: if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError() raise NotCompletionAppError()
parser = ( payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = naive_utc_now() installed_app.last_used_at = naive_utc_now()
@ -123,22 +153,15 @@ class CompletionStopApi(InstalledAppResource):
endpoint="installed_app_chat_completion", endpoint="installed_app_chat_completion",
) )
class ChatApi(InstalledAppResource): class ChatApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = ( payload = ChatMessagePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
args["auto_generate_name"] = False args["auto_generate_name"] = False

View File

@ -1,14 +1,18 @@
from flask_restx import marshal_with, reqparse from typing import Any
from flask_restx.inputs import int_range from uuid import UUID
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from libs.login import current_user from libs.login import current_user
from models import Account from models import Account
from models.model import AppMode from models.model import AppMode
@ -19,29 +23,51 @@ from services.web_conversation_service import WebConversationService
from .. import console_ns from .. import console_ns
class ConversationListQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
@console_ns.route( @console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations", "/installed-apps/<uuid:installed_app_id>/conversations",
endpoint="installed_app_conversations", endpoint="installed_app_conversations",
) )
class ConversationListApi(InstalledAppResource): class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app): def get(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = ( raw_args: dict[str, Any] = {
reqparse.RequestParser() "last_id": request.args.get("last_id"),
.add_argument("last_id", type=uuid_value, location="args") "limit": request.args.get("limit", default=20, type=int),
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") "pinned": request.args.get("pinned"),
.add_argument("pinned", type=str, choices=["true", "false", None], location="args") }
) if raw_args["last_id"] is None:
args = parser.parse_args() raw_args["last_id"] = None
pinned_value = raw_args["pinned"]
pinned = None if isinstance(pinned_value, str):
if "pinned" in args and args["pinned"] is not None: raw_args["pinned"] = pinned_value == "true"
pinned = args["pinned"] == "true" args = ConversationListQuery.model_validate(raw_args)
try: try:
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
@ -51,10 +77,10 @@ class ConversationListApi(InstalledAppResource):
session=session, session=session,
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
last_id=args["last_id"], last_id=str(args.last_id) if args.last_id else None,
limit=args["limit"], limit=args.limit,
invoke_from=InvokeFrom.EXPLORE, invoke_from=InvokeFrom.EXPLORE,
pinned=pinned, pinned=args.pinned,
) )
except LastConversationNotExistsError: except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
@ -88,6 +114,7 @@ class ConversationApi(InstalledAppResource):
) )
class ConversationRenameApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
@ -96,18 +123,13 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
parser = ( payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
try: try:
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance") raise ValueError("current_user must be an Account instance")
return ConversationService.rename( return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"] app_model, conversation_id, current_user, payload.name, payload.auto_generate
) )
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@ -1,9 +1,13 @@
import logging import logging
from typing import Literal
from uuid import UUID
from flask_restx import marshal_with, reqparse from flask import request
from flask_restx.inputs import int_range from flask_restx import marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console.app.error import ( from controllers.console.app.error import (
AppMoreLikeThisDisabledError, AppMoreLikeThisDisabledError,
CompletionRequestError, CompletionRequestError,
@ -22,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper from libs import helper
from libs.helper import uuid_value
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
@ -40,12 +43,31 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"]
register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
@console_ns.route( @console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages", "/installed-apps/<uuid:installed_app_id>/messages",
endpoint="installed_app_messages", endpoint="installed_app_messages",
) )
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
args = MessageListQuery.model_validate(request.args.to_dict())
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
try: try:
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
) )
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource):
endpoint="installed_app_message_feedback", endpoint="installed_app_message_feedback",
) )
class MessageFeedbackApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
def post(self, installed_app, message_id): def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
message_id = str(message_id) message_id = str(message_id)
parser = ( payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json")
)
args = parser.parse_args()
try: try:
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
user=current_user, user=current_user,
rating=args.get("rating"), rating=payload.rating,
content=args.get("content"), content=payload.content,
) )
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource):
endpoint="installed_app_more_like_this", endpoint="installed_app_more_like_this",
) )
class MessageMoreLikeThisApi(InstalledAppResource): class MessageMoreLikeThisApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser().add_argument( args = MoreLikeThisQuery.model_validate(request.args.to_dict())
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args.response_mode == "streaming"
try: try:
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages from constants.languages import languages
from controllers.console import console_ns from controllers.console import console_ns
@ -35,20 +37,26 @@ recommended_app_list_fields = {
} }
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args") class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
console_ns.schema_model(
RecommendedAppsQuery.__name__,
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/explore/apps") @console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource): class RecommendedAppListApi(Resource):
@console_ns.expect(parser_apps) @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(recommended_app_list_fields) @marshal_with(recommended_app_list_fields)
def get(self): def get(self):
# language args # language args
args = parser_apps.parse_args() args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
language = args.language
language = args.get("language")
if language and language in languages: if language and language in languages:
language_prefix = language language_prefix = language
elif current_user and current_user.interface_language: elif current_user and current_user.interface_language:

View File

@ -1,16 +1,33 @@
from flask_restx import fields, marshal_with, reqparse from uuid import UUID
from flask_restx.inputs import int_range
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUID
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
feedback_fields = {"rating": fields.String} feedback_fields = {"rating": fields.String}
message_fields = { message_fields = {
@ -33,32 +50,33 @@ class SavedMessageListApi(InstalledAppResource):
} }
@marshal_with(saved_message_infinite_scroll_pagination_fields) @marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = ( args = SavedMessageListQuery.model_validate(request.args.to_dict())
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args") return SavedMessageService.pagination_by_last_id(
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") app_model,
current_user,
str(args.last_id) if args.last_id else None,
args.limit,
) )
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app): def post(self, installed_app):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args()
try: try:
SavedMessageService.save(app_model, current_user, args["message_id"]) SavedMessageService.save(app_model, current_user, str(payload.message_id))
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

View File

@ -1,8 +1,10 @@
import logging import logging
from typing import Any
from flask_restx import reqparse from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.common.schema import register_schema_model
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -32,8 +34,17 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run") @console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource): class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
def post(self, installed_app: InstalledApp): def post(self, installed_app: InstalledApp):
""" """
Run workflow Run workflow
@ -46,12 +57,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
parser = ( payload = WorkflowRunPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True

View File

@ -45,6 +45,9 @@ class FileApi(Resource):
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200 }, 200
@setup_required @setup_required

View File

@ -1,13 +1,13 @@
import os import os
from flask import session from flask import session
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
@ -15,6 +15,18 @@ from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30)
console_ns.schema_model(
InitValidatePayload.__name__,
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/init") @console_ns.route("/init")
class InitValidateAPI(Resource): class InitValidateAPI(Resource):
@ -37,12 +49,7 @@ class InitValidateAPI(Resource):
@console_ns.doc("validate_init_password") @console_ns.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition") @console_ns.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect( @console_ns.expect(console_ns.models[InitValidatePayload.__name__])
console_ns.model(
"InitValidateRequest",
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
)
)
@console_ns.response( @console_ns.response(
201, 201,
"Success", "Success",
@ -57,8 +64,8 @@ class InitValidateAPI(Resource):
if tenant_count > 0: if tenant_count > 0:
raise AlreadySetupError() raise AlreadySetupError()
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json") payload = InitValidatePayload.model_validate(console_ns.payload)
input_password = parser.parse_args()["password"] input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"): if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False session["is_init_validated"] = False

View File

@ -1,7 +1,8 @@
import urllib.parse import urllib.parse
import httpx import httpx
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
import services import services
from controllers.common import helpers from controllers.common import helpers
@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource):
} }
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
console_ns.schema_model(
RemoteFileUploadPayload.__name__,
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/remote-files/upload") @console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource): class RemoteFileUploadApi(Resource):
@console_ns.expect(parser_upload) @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@marshal_with(file_fields_with_signed_url) @marshal_with(file_fields_with_signed_url)
def post(self): def post(self):
args = parser_upload.parse_args() args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
url = args["url"]
try: try:
resp = ssrf_proxy.head(url=url) resp = ssrf_proxy.head(url=url)

View File

@ -1,8 +1,9 @@
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from configs import dify_config from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.model import DifySetup, db from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SetupRequestPayload(BaseModel):
email: EmailStr = Field(..., description="Admin email address")
name: str = Field(..., max_length=30, description="Admin name (max 30 characters)")
password: str = Field(..., description="Admin password")
language: str | None = Field(default=None, description="Admin language")
@field_validator("password")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
console_ns.schema_model(
SetupRequestPayload.__name__,
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/setup") @console_ns.route("/setup")
class SetupApi(Resource): class SetupApi(Resource):
@ -42,17 +63,7 @@ class SetupApi(Resource):
@console_ns.doc("setup_system") @console_ns.doc("setup_system")
@console_ns.doc(description="Initialize system setup with admin account") @console_ns.doc(description="Initialize system setup with admin account")
@console_ns.expect( @console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
console_ns.model(
"SetupRequest",
{
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
"language": fields.String(required=False, description="Admin language"),
},
)
)
@console_ns.response( @console_ns.response(
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")}) 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
) )
@ -72,22 +83,15 @@ class SetupApi(Resource):
if not get_init_validate_status(): if not get_init_validate_status():
raise NotInitValidateError() raise NotInitValidateError()
parser = ( args = SetupRequestPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
# setup # setup
RegisterService.setup( RegisterService.setup(
email=args["email"], email=args.email,
name=args["name"], name=args.name,
password=args["password"], password=args.password,
ip_address=extract_remote_ip(request), ip_address=extract_remote_ip(request),
language=args["language"], language=args.language,
) )
return {"result": "success"}, 201 return {"result": "success"}, 201

View File

@ -2,8 +2,10 @@ import json
import logging import logging
import httpx import httpx
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from packaging import version from packaging import version
from pydantic import BaseModel, Field
from configs import dify_config from configs import dify_config
@ -11,8 +13,14 @@ from . import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
parser = reqparse.RequestParser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version" class VersionQuery(BaseModel):
current_version: str = Field(..., description="Current application version")
console_ns.schema_model(
VersionQuery.__name__,
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
) )
@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument(
class VersionApi(Resource): class VersionApi(Resource):
@console_ns.doc("check_version_update") @console_ns.doc("check_version_update")
@console_ns.doc(description="Check for application version updates") @console_ns.doc(description="Check for application version updates")
@console_ns.expect(parser) @console_ns.expect(console_ns.models[VersionQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
@ -37,7 +45,7 @@ class VersionApi(Resource):
) )
def get(self): def get(self):
"""Check for application version updates""" """Check for application version updates"""
args = parser.parse_args() args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
check_update_url = dify_config.CHECK_UPDATE_URL check_update_url = dify_config.CHECK_UPDATE_URL
result = { result = {
@ -57,16 +65,16 @@ class VersionApi(Resource):
try: try:
response = httpx.get( response = httpx.get(
check_update_url, check_update_url,
params={"current_version": args["current_version"]}, params={"current_version": args.current_version},
timeout=httpx.Timeout(timeout=10.0, connect=3.0), timeout=httpx.Timeout(timeout=10.0, connect=3.0),
) )
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"] result["version"] = args.current_version
return result return result
content = json.loads(response.content) content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
result["version"] = content["version"] result["version"] = content["version"]
result["release_date"] = content["releaseDate"] result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"] result["release_notes"] = content["releaseNotes"]

View File

@ -37,7 +37,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService from services.account_service import AccountService
@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel):
class AccountDeletionFeedbackPayload(BaseModel): class AccountDeletionFeedbackPayload(BaseModel):
email: str email: EmailStr
feedback: str feedback: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class EducationActivatePayload(BaseModel): class EducationActivatePayload(BaseModel):
token: str token: str
@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel):
class ChangeEmailSendPayload(BaseModel): class ChangeEmailSendPayload(BaseModel):
email: str email: EmailStr
language: str | None = None language: str | None = None
phase: str | None = None phase: str | None = None
token: str | None = None token: str | None = None
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailValidityPayload(BaseModel): class ChangeEmailValidityPayload(BaseModel):
email: str email: EmailStr
code: str code: str
token: str token: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailResetPayload(BaseModel): class ChangeEmailResetPayload(BaseModel):
new_email: str new_email: EmailStr
token: str token: str
@field_validator("new_email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class CheckEmailUniquePayload(BaseModel): class CheckEmailUniquePayload(BaseModel):
email: str email: EmailStr
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
def reg(cls: type[BaseModel]): def reg(cls: type[BaseModel]):

View File

@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True) @console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required @is_admin_or_owner_required
@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id=tenant_id, provider_name=provider tenant_id=tenant_id, provider_name=provider
) )
else: else:
model_type = args.model_type # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
) )
return jsonable_encoder( return jsonable_encoder(

View File

@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
from .. import console_ns from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required from ..wraps import (
account_initialization_required,
edit_permission_required,
is_admin_or_owner_required,
setup_required,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource):
class TriggerSubscriptionListApi(Resource): class TriggerSubscriptionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required @edit_permission_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider""" """List all trigger subscriptions for the current tenant's provider"""
@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(parser) @console_ns.expect(parser)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required @edit_permission_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
"""Add a new subscription instance for a trigger provider""" """Add a new subscription instance for a trigger provider"""
@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
class TriggerSubscriptionBuilderGetApi(Resource): class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
def get(self, provider, subscription_builder_id): def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider""" """Get a subscription instance for a trigger provider"""
@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(parser_api) @console_ns.expect(parser_api)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required @edit_permission_required
@account_initialization_required @account_initialization_required
def post(self, provider, subscription_builder_id): def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider""" """Verify a subscription instance for a trigger provider"""
@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(parser_update_api) @console_ns.expect(parser_update_api)
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
def post(self, provider, subscription_builder_id): def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider""" """Update a subscription instance for a trigger provider"""
@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
class TriggerSubscriptionBuilderLogsApi(Resource): class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
def get(self, provider, subscription_builder_id): def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider""" """Get the request logs for a subscription instance for a trigger provider"""
@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(parser_update_api) @console_ns.expect(parser_update_api)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required @edit_permission_required
@account_initialization_required @account_initialization_required
def post(self, provider, subscription_builder_id): def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider""" """Build a subscription instance for a trigger provider"""

View File

@ -1,7 +1,8 @@
from urllib.parse import quote from urllib.parse import quote
from flask import Response, request from flask import Response, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services
@ -11,6 +12,26 @@ from extensions.ext_database import db
from services.account_service import TenantService from services.account_service import TenantService
from services.file_service import FileService from services.file_service import FileService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class FileSignatureQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp used in the signature")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
class FilePreviewQuery(FileSignatureQuery):
as_attachment: bool = Field(default=False, description="Whether to download as attachment")
files_ns.schema_model(
FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
files_ns.schema_model(
FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/<uuid:file_id>/image-preview") @files_ns.route("/<uuid:file_id>/image-preview")
class ImagePreviewApi(Resource): class ImagePreviewApi(Resource):
@ -36,12 +57,10 @@ class ImagePreviewApi(Resource):
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
timestamp = request.args.get("timestamp") args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
nonce = request.args.get("nonce") timestamp = args.timestamp
sign = request.args.get("sign") nonce = args.nonce
sign = args.sign
if not timestamp or not nonce or not sign:
return {"content": "Invalid request."}, 400
try: try:
generator, mimetype = FileService(db.engine).get_image_preview( generator, mimetype = FileService(db.engine).get_image_preview(
@ -80,25 +99,14 @@ class FilePreviewApi(Resource):
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
parser = ( args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
return {"content": "Invalid request."}, 400
try: try:
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
file_id=file_id, file_id=file_id,
timestamp=args["timestamp"], timestamp=args.timestamp,
nonce=args["nonce"], nonce=args.nonce,
sign=args["sign"], sign=args.sign,
) )
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
@ -125,7 +133,7 @@ class FilePreviewApi(Resource):
response.headers["Accept-Ranges"] = "bytes" response.headers["Accept-Ranges"] = "bytes"
if upload_file.size > 0: if upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size) response.headers["Content-Length"] = str(upload_file.size)
if args["as_attachment"]: if args.as_attachment:
encoded_filename = quote(upload_file.name) encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream" response.headers["Content-Type"] = "application/octet-stream"

View File

@ -1,7 +1,8 @@
from urllib.parse import quote from urllib.parse import quote
from flask import Response from flask import Response, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError from controllers.common.errors import UnsupportedFileTypeError
@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db from extensions.ext_database import db as global_db
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ToolFileQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp")
nonce: str = Field(..., description="Random nonce")
sign: str = Field(..., description="HMAC signature")
as_attachment: bool = Field(default=False, description="Download as attachment")
files_ns.schema_model(
ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/tools/<uuid:file_id>.<string:extension>") @files_ns.route("/tools/<uuid:file_id>.<string:extension>")
class ToolFileApi(Resource): class ToolFileApi(Resource):
@ -36,18 +51,8 @@ class ToolFileApi(Resource):
def get(self, file_id, extension): def get(self, file_id, extension):
file_id = str(file_id) file_id = str(file_id)
parser = ( args = ToolFileQuery.model_validate(request.args.to_dict())
reqparse.RequestParser() if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not verify_tool_file_signature(
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
):
raise Forbidden("Invalid request.") raise Forbidden("Invalid request.")
try: try:
@ -69,7 +74,7 @@ class ToolFileApi(Resource):
) )
if tool_file.size > 0: if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size) response.headers["Content-Length"] = str(tool_file.size)
if args["as_attachment"]: if args.as_attachment:
encoded_filename = quote(tool_file.name) encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"

View File

@ -1,40 +1,45 @@
from mimetypes import guess_extension from mimetypes import guess_extension
from flask_restx import Resource, reqparse from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from controllers.console.wraps import setup_required
from controllers.files import files_ns
from controllers.inner_api.plugin.wraps import get_user
from core.file.helpers import verify_plugin_file_signature from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import build_file_model from fields.file_fields import build_file_model
# Define parser for both documentation and validation from ..common.errors import (
upload_parser = ( FileTooLargeError,
reqparse.RequestParser() UnsupportedFileTypeError,
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload") )
.add_argument( from ..console.wraps import setup_required
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification" from ..files import files_ns
) from ..inner_api.plugin.wraps import get_user
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation") DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
class PluginUploadQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp for signature verification")
nonce: str = Field(..., description="Random nonce for signature verification")
sign: str = Field(..., description="HMAC signature")
tenant_id: str = Field(..., description="Tenant identifier")
user_id: str | None = Field(default=None, description="User identifier")
files_ns.schema_model(
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
) )
@files_ns.route("/upload/for-plugin") @files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource): class PluginUploadFileApi(Resource):
@setup_required @setup_required
@files_ns.expect(upload_parser) @files_ns.expect(files_ns.models[PluginUploadQuery.__name__])
@files_ns.doc("upload_plugin_file") @files_ns.doc("upload_plugin_file")
@files_ns.doc(description="Upload a file for plugin usage with signature verification") @files_ns.doc(description="Upload a file for plugin usage with signature verification")
@files_ns.doc( @files_ns.doc(
@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource):
FileTooLargeError: File exceeds size limit FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported UnsupportedFileTypeError: File type not supported
""" """
# Parse and validate all arguments args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = upload_parser.parse_args()
file: FileStorage = args["file"] file: FileStorage | None = request.files.get("file")
timestamp: str = args["timestamp"] if file is None:
nonce: str = args["nonce"] raise Forbidden("File is required.")
sign: str = args["sign"]
tenant_id: str = args["tenant_id"] timestamp = args.timestamp
user_id: str | None = args.get("user_id") nonce = args.nonce
sign = args.sign
tenant_id = args.tenant_id
user_id = args.user_id
user = get_user(tenant_id, user_id) user = get_user(tenant_id, user_id)
filename: str | None = file.filename filename: str | None = file.filename

View File

@ -1,29 +1,38 @@
from flask_restx import Resource, reqparse from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
from tasks.mail_inner_task import send_inner_email_task from tasks.mail_inner_task import send_inner_email_task
_mail_parser = (
reqparse.RequestParser() class InnerMailPayload(BaseModel):
.add_argument("to", type=str, action="append", required=True) to: list[str] = Field(description="Recipient email addresses", min_length=1)
.add_argument("subject", type=str, required=True) subject: str
.add_argument("body", type=str, required=True) body: str
.add_argument("substitutions", type=dict, required=False) substitutions: dict[str, Any] | None = None
)
register_schema_model(inner_api_ns, InnerMailPayload)
class BaseMail(Resource): class BaseMail(Resource):
"""Shared logic for sending an inner email.""" """Shared logic for sending an inner email."""
@inner_api_ns.doc("send_inner_mail")
@inner_api_ns.doc(description="Send internal email")
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
def post(self): def post(self):
args = _mail_parser.parse_args() args = InnerMailPayload.model_validate(inner_api_ns.payload or {})
send_inner_email_task.delay( # type: ignore send_inner_email_task.delay(
to=args["to"], to=args.to,
subject=args["subject"], subject=args.subject,
body=args["body"], body=args.body,
substitutions=args["substitutions"], substitutions=args.substitutions, # type: ignore
) )
return {"message": "success"}, 200 return {"message": "success"}, 200
@ -34,7 +43,7 @@ class EnterpriseMail(BaseMail):
@inner_api_ns.doc("send_enterprise_mail") @inner_api_ns.doc("send_enterprise_mail")
@inner_api_ns.doc(description="Send internal email for enterprise features") @inner_api_ns.doc(description="Send internal email for enterprise features")
@inner_api_ns.expect(_mail_parser) @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
@inner_api_ns.doc( @inner_api_ns.doc(
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
) )
@ -56,7 +65,7 @@ class BillingMail(BaseMail):
@inner_api_ns.doc("send_billing_mail") @inner_api_ns.doc("send_billing_mail")
@inner_api_ns.doc(description="Send internal email for billing notifications") @inner_api_ns.doc(description="Send internal email for billing notifications")
@inner_api_ns.expect(_mail_parser) @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
@inner_api_ns.doc( @inner_api_ns.doc(
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
) )

View File

@ -1,10 +1,9 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar, cast from typing import ParamSpec, TypeVar
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
from flask_restx import reqparse
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -17,6 +16,11 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
user_id: str
def get_user(tenant_id: str, user_id: str | None) -> EndUser: def get_user(tenant_id: str, user_id: str | None) -> EndUser:
""" """
Get current user Get current user
@ -67,58 +71,45 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model return user_model
def get_user_tenant(view: Callable[P, R] | None = None): def get_user_tenant(view_func: Callable[P, R]):
def decorator(view_func: Callable[P, R]): @wraps(view_func)
@wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs): payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
# fetch json body
parser = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="json")
.add_argument("user_id", type=str, required=True, location="json")
)
p = parser.parse_args() user_id = payload.user_id
tenant_id = payload.tenant_id
user_id = cast(str, p.get("user_id")) if not tenant_id:
tenant_id = cast(str, p.get("tenant_id")) raise ValueError("tenant_id is required")
if not tenant_id: if not user_id:
raise ValueError("tenant_id is required") user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
if not user_id: try:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID tenant_model = (
db.session.query(Tenant)
try: .where(
tenant_model = ( Tenant.id == tenant_id,
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
.first()
) )
except Exception: .first()
raise ValueError("tenant not found") )
except Exception:
raise ValueError("tenant not found")
if not tenant_model: if not tenant_model:
raise ValueError("tenant not found") raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model kwargs["tenant_model"] = tenant_model
user = get_user(tenant_id, user_id) user = get_user(tenant_id, user_id)
kwargs["user_model"] = user kwargs["user_model"] = user
current_app.login_manager._update_request_context_with_user(user) # type: ignore current_app.login_manager._update_request_context_with_user(user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
return decorated_view return decorated_view
if view is None:
return decorator
else:
return decorator(view)
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):

View File

@ -1,7 +1,9 @@
import json import json
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel
from controllers.common.schema import register_schema_models
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only from controllers.inner_api.wraps import enterprise_inner_api_only
@ -11,12 +13,25 @@ from models import Account
from services.account_service import TenantService from services.account_service import TenantService
class WorkspaceCreatePayload(BaseModel):
name: str
owner_email: str
class WorkspaceOwnerlessPayload(BaseModel):
name: str
register_schema_models(inner_api_ns, WorkspaceCreatePayload, WorkspaceOwnerlessPayload)
@inner_api_ns.route("/enterprise/workspace") @inner_api_ns.route("/enterprise/workspace")
class EnterpriseWorkspace(Resource): class EnterpriseWorkspace(Resource):
@setup_required @setup_required
@enterprise_inner_api_only @enterprise_inner_api_only
@inner_api_ns.doc("create_enterprise_workspace") @inner_api_ns.doc("create_enterprise_workspace")
@inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment") @inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment")
@inner_api_ns.expect(inner_api_ns.models[WorkspaceCreatePayload.__name__])
@inner_api_ns.doc( @inner_api_ns.doc(
responses={ responses={
200: "Workspace created successfully", 200: "Workspace created successfully",
@ -25,18 +40,13 @@ class EnterpriseWorkspace(Resource):
} }
) )
def post(self): def post(self):
parser = ( args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("owner_email", type=str, required=True, location="json")
)
args = parser.parse_args()
account = db.session.query(Account).filter_by(email=args["owner_email"]).first() account = db.session.query(Account).filter_by(email=args.owner_email).first()
if account is None: if account is None:
return {"message": "owner account not found."}, 404 return {"message": "owner account not found."}, 404
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
@ -62,6 +72,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
@enterprise_inner_api_only @enterprise_inner_api_only
@inner_api_ns.doc("create_enterprise_workspace_ownerless") @inner_api_ns.doc("create_enterprise_workspace_ownerless")
@inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment") @inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment")
@inner_api_ns.expect(inner_api_ns.models[WorkspaceOwnerlessPayload.__name__])
@inner_api_ns.doc( @inner_api_ns.doc(
responses={ responses={
200: "Workspace created successfully", 200: "Workspace created successfully",
@ -70,10 +81,9 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
} }
) )
def post(self): def post(self):
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json") args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})
args = parser.parse_args()
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
tenant_was_created.send(tenant) tenant_was_created.send(tenant)

View File

@ -1,10 +1,11 @@
from typing import Union from typing import Any, Union
from flask import Response from flask import Response
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import ValidationError from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity from core.app.app_config.entities import VariableEntity
@ -24,27 +25,19 @@ class MCPRequestError(Exception):
super().__init__(message) super().__init__(message)
def int_or_str(value): class MCPRequestPayload(BaseModel):
"""Validate that a value is either an integer or string.""" jsonrpc: str = Field(description="JSON-RPC version (should be '2.0')")
if isinstance(value, (int, str)): method: str = Field(description="The method to invoke")
return value params: dict[str, Any] | None = Field(default=None, description="Parameters for the method")
else: id: int | str | None = Field(default=None, description="Request ID for tracking responses")
return None
# Define parser for both documentation and validation register_schema_model(mcp_ns, MCPRequestPayload)
mcp_request_parser = (
reqparse.RequestParser()
.add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
.add_argument("method", type=str, required=True, location="json", help="The method to invoke")
.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
.add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
)
@mcp_ns.route("/server/<string:server_code>/mcp") @mcp_ns.route("/server/<string:server_code>/mcp")
class MCPAppApi(Resource): class MCPAppApi(Resource):
@mcp_ns.expect(mcp_request_parser) @mcp_ns.expect(mcp_ns.models[MCPRequestPayload.__name__])
@mcp_ns.doc("handle_mcp_request") @mcp_ns.doc("handle_mcp_request")
@mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server") @mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server")
@mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"}) @mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"})
@ -70,9 +63,9 @@ class MCPAppApi(Resource):
Raises: Raises:
ValidationError: Invalid request format or parameters ValidationError: Invalid request format or parameters
""" """
args = mcp_request_parser.parse_args() args = MCPRequestPayload.model_validate(mcp_ns.payload or {})
request_id: Union[int, str] | None = args.get("id") request_id: Union[int, str] | None = args.id
mcp_request = self._parse_mcp_request(args) mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app # Get MCP server and app

View File

@ -1,9 +1,11 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse from flask_restx import Api, Namespace, Resource, fields
from flask_restx.api import HTTPStatus from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model
from models.model import App from models.model import App
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
# Define parsers for annotation API
annotation_create_parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json", help="Annotation question")
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
)
annotation_reply_action_parser = ( class AnnotationCreatePayload(BaseModel):
reqparse.RequestParser() question: str = Field(description="Annotation question")
.add_argument( answer: str = Field(description="Annotation answer")
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
)
.add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name") class AnnotationReplyActionPayload(BaseModel):
.add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name") score_threshold: float = Field(description="Score threshold for annotation matching")
) embedding_provider_name: str = Field(description="Embedding provider name")
embedding_model_name: str = Field(description="Embedding model name")
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
@service_api_ns.route("/apps/annotation-reply/<string:action>") @service_api_ns.route("/apps/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@service_api_ns.expect(annotation_reply_action_parser) @service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
@service_api_ns.doc("annotation_reply_action") @service_api_ns.doc("annotation_reply_action")
@service_api_ns.doc(description="Enable or disable annotation reply feature") @service_api_ns.doc(description="Enable or disable annotation reply feature")
@service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"}) @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource):
@validate_app_token @validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]): def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature.""" """Enable or disable annotation reply feature."""
args = annotation_reply_action_parser.parse_args() args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
if action == "enable": if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id) result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable": elif action == "disable":
@ -126,7 +126,7 @@ class AnnotationListApi(Resource):
"page": page, "page": page,
} }
@service_api_ns.expect(annotation_create_parser) @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("create_annotation") @service_api_ns.doc("create_annotation")
@service_api_ns.doc(description="Create a new annotation") @service_api_ns.doc(description="Create a new annotation")
@service_api_ns.doc( @service_api_ns.doc(
@ -139,14 +139,14 @@ class AnnotationListApi(Resource):
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App): def post(self, app_model: App):
"""Create a new annotation.""" """Create a new annotation."""
args = annotation_create_parser.parse_args() args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation, 201 return annotation, 201
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>") @service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource): class AnnotationUpdateDeleteApi(Resource):
@service_api_ns.expect(annotation_create_parser) @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("update_annotation") @service_api_ns.doc("update_annotation")
@service_api_ns.doc(description="Update an existing annotation") @service_api_ns.doc(description="Update an existing annotation")
@service_api_ns.doc(params={"annotation_id": "Annotation ID"}) @service_api_ns.doc(params={"annotation_id": "Annotation ID"})
@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource):
@service_api_ns.marshal_with(build_annotation_model(service_api_ns)) @service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id: str): def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation.""" """Update an existing annotation."""
args = annotation_create_parser.parse_args() args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation return annotation

View File

@ -1,10 +1,12 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
AppUnavailableError, AppUnavailableError,
@ -84,19 +86,19 @@ class AudioApi(Resource):
raise InternalServerError() raise InternalServerError()
# Define parser for text-to-audio API class TextToAudioPayload(BaseModel):
text_to_audio_parser = ( message_id: str | None = Field(default=None, description="Message ID")
reqparse.RequestParser() voice: str | None = Field(default=None, description="Voice to use for TTS")
.add_argument("message_id", type=str, required=False, location="json", help="Message ID") text: str | None = Field(default=None, description="Text to convert to audio")
.add_argument("voice", type=str, location="json", help="Voice to use for TTS") streaming: bool | None = Field(default=None, description="Enable streaming response")
.add_argument("text", type=str, location="json", help="Text to convert to audio")
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
) register_schema_model(service_api_ns, TextToAudioPayload)
@service_api_ns.route("/text-to-audio") @service_api_ns.route("/text-to-audio")
class TextApi(Resource): class TextApi(Resource):
@service_api_ns.expect(text_to_audio_parser) @service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
@service_api_ns.doc("text_to_audio") @service_api_ns.doc("text_to_audio")
@service_api_ns.doc(description="Convert text to audio using text-to-speech") @service_api_ns.doc(description="Convert text to audio using text-to-speech")
@service_api_ns.doc( @service_api_ns.doc(
@ -114,11 +116,11 @@ class TextApi(Resource):
Converts the provided text to audio using the specified voice. Converts the provided text to audio using the specified voice.
""" """
try: try:
args = text_to_audio_parser.parse_args() payload = TextToAudioPayload.model_validate(service_api_ns.payload or {})
message_id = args.get("message_id", None) message_id = payload.message_id
text = args.get("text", None) text = payload.text
voice = args.get("voice", None) voice = payload.voice
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
) )

View File

@ -1,10 +1,14 @@
import logging import logging
from typing import Any, Literal
from uuid import UUID
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
AppUnavailableError, AppUnavailableError,
@ -26,7 +30,6 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService from services.app_task_service import AppTaskService
@ -36,40 +39,43 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define parser for completion API class CompletionRequestPayload(BaseModel):
completion_parser = ( inputs: dict[str, Any]
reqparse.RequestParser() query: str = Field(default="")
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion") files: list[dict[str, Any]] | None = None
.add_argument("query", type=str, location="json", default="", help="The query string") response_mode: Literal["blocking", "streaming"] | None = None
.add_argument("files", type=list, required=False, location="json", help="List of file attachments") retriever_from: str = Field(default="dev")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
)
# Define parser for chat API
chat_parser = ( class ChatRequestPayload(BaseModel):
reqparse.RequestParser() inputs: dict[str, Any]
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") query: str
.add_argument("query", type=str, required=True, location="json", help="The chat query") files: list[dict[str, Any]] | None = None
.add_argument("files", type=list, required=False, location="json", help="List of file attachments") response_mode: Literal["blocking", "streaming"] | None = None
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode") conversation_id: str | None = Field(default=None, description="Conversation UUID")
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") retriever_from: str = Field(default="dev")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
.add_argument( workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
"auto_generate_name",
type=bool, @field_validator("conversation_id", mode="before")
required=False, @classmethod
default=True, def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
location="json", """Allow missing or blank conversation IDs; enforce UUID format when provided."""
help="Auto generate conversation name", if not value:
) return None
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
) try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@service_api_ns.route("/completion-messages") @service_api_ns.route("/completion-messages")
class CompletionApi(Resource): class CompletionApi(Resource):
@service_api_ns.expect(completion_parser) @service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
@service_api_ns.doc("create_completion") @service_api_ns.doc("create_completion")
@service_api_ns.doc(description="Create a completion for the given prompt") @service_api_ns.doc(description="Create a completion for the given prompt")
@service_api_ns.doc( @service_api_ns.doc(
@ -91,12 +97,13 @@ class CompletionApi(Resource):
if app_model.mode != AppMode.COMPLETION: if app_model.mode != AppMode.COMPLETION:
raise AppUnavailableError() raise AppUnavailableError()
args = completion_parser.parse_args() payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
args = payload.model_dump(exclude_none=True)
if external_trace_id: if external_trace_id:
args["external_trace_id"] = external_trace_id args["external_trace_id"] = external_trace_id
streaming = args["response_mode"] == "streaming" streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
@ -162,7 +169,7 @@ class CompletionStopApi(Resource):
@service_api_ns.route("/chat-messages") @service_api_ns.route("/chat-messages")
class ChatApi(Resource): class ChatApi(Resource):
@service_api_ns.expect(chat_parser) @service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
@service_api_ns.doc("create_chat_message") @service_api_ns.doc("create_chat_message")
@service_api_ns.doc(description="Send a message in a chat conversation") @service_api_ns.doc(description="Send a message in a chat conversation")
@service_api_ns.doc( @service_api_ns.doc(
@ -186,13 +193,14 @@ class ChatApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
args = chat_parser.parse_args() payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
args = payload.model_dump(exclude_none=True)
if external_trace_id: if external_trace_id:
args["external_trace_id"] = external_trace_id args["external_trace_id"] = external_trace_id
streaming = args["response_mode"] == "streaming" streaming = payload.response_mode == "streaming"
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(

View File

@ -1,10 +1,15 @@
from flask_restx import Resource, reqparse from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus from flask_restx._http import HTTPStatus
from flask_restx.inputs import int_range from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
@ -19,74 +24,51 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model, build_conversation_variable_model,
) )
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
# Define parsers for conversation APIs
conversation_list_parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of conversations to return",
)
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
help="Sort order for conversations",
)
)
conversation_rename_parser = ( class ConversationListQuery(BaseModel):
reqparse.RequestParser() last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
.add_argument("name", type=str, required=False, location="json", help="New conversation name") limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
.add_argument( sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
"auto_generate", default="-updated_at", description="Sort order for conversations"
type=bool,
required=False,
default=False,
location="json",
help="Auto-generate conversation name",
) )
)
conversation_variables_parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of variables to return",
)
)
conversation_variable_update_parser = reqparse.RequestParser().add_argument( class ConversationRenamePayload(BaseModel):
# using lambda is for passing the already-typed value without modification name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
# if no lambda, it will be converted to string auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
# the string cannot be converted using json.loads
"value", @model_validator(mode="after")
required=True, def validate_name_requirement(self):
location="json", if not self.auto_generate:
type=lambda x: x, if self.name is None or not self.name.strip():
help="New value for the conversation variable", raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
class ConversationVariableUpdatePayload(BaseModel):
value: Any
register_schema_models(
service_api_ns,
ConversationListQuery,
ConversationRenamePayload,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
) )
@service_api_ns.route("/conversations") @service_api_ns.route("/conversations")
class ConversationApi(Resource): class ConversationApi(Resource):
@service_api_ns.expect(conversation_list_parser) @service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__])
@service_api_ns.doc("list_conversations") @service_api_ns.doc("list_conversations")
@service_api_ns.doc(description="List all conversations for the current user") @service_api_ns.doc(description="List all conversations for the current user")
@service_api_ns.doc( @service_api_ns.doc(
@ -107,7 +89,8 @@ class ConversationApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
args = conversation_list_parser.parse_args() query_args = ConversationListQuery.model_validate(request.args.to_dict())
last_id = str(query_args.last_id) if query_args.last_id else None
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:
@ -115,10 +98,10 @@ class ConversationApi(Resource):
session=session, session=session,
app_model=app_model, app_model=app_model,
user=end_user, user=end_user,
last_id=args["last_id"], last_id=last_id,
limit=args["limit"], limit=query_args.limit,
invoke_from=InvokeFrom.SERVICE_API, invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"], sort_by=query_args.sort_by,
) )
except services.errors.conversation.LastConversationNotExistsError: except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
@ -155,7 +138,7 @@ class ConversationDetailApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/name") @service_api_ns.route("/conversations/<uuid:c_id>/name")
class ConversationRenameApi(Resource): class ConversationRenameApi(Resource):
@service_api_ns.expect(conversation_rename_parser) @service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
@service_api_ns.doc("rename_conversation") @service_api_ns.doc("rename_conversation")
@service_api_ns.doc(description="Rename a conversation or auto-generate a name") @service_api_ns.doc(description="Rename a conversation or auto-generate a name")
@service_api_ns.doc(params={"c_id": "Conversation ID"}) @service_api_ns.doc(params={"c_id": "Conversation ID"})
@ -176,17 +159,17 @@ class ConversationRenameApi(Resource):
conversation_id = str(c_id) conversation_id = str(c_id)
args = conversation_rename_parser.parse_args() payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
try: try:
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@service_api_ns.route("/conversations/<uuid:c_id>/variables") @service_api_ns.route("/conversations/<uuid:c_id>/variables")
class ConversationVariablesApi(Resource): class ConversationVariablesApi(Resource):
@service_api_ns.expect(conversation_variables_parser) @service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__])
@service_api_ns.doc("list_conversation_variables") @service_api_ns.doc("list_conversation_variables")
@service_api_ns.doc(description="List all variables for a conversation") @service_api_ns.doc(description="List all variables for a conversation")
@service_api_ns.doc(params={"c_id": "Conversation ID"}) @service_api_ns.doc(params={"c_id": "Conversation ID"})
@ -211,11 +194,12 @@ class ConversationVariablesApi(Resource):
conversation_id = str(c_id) conversation_id = str(c_id)
args = conversation_variables_parser.parse_args() query_args = ConversationVariablesQuery.model_validate(request.args.to_dict())
last_id = str(query_args.last_id) if query_args.last_id else None
try: try:
return ConversationService.get_conversational_variable( return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, args["limit"], args["last_id"] app_model, conversation_id, end_user, query_args.limit, last_id
) )
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -223,7 +207,7 @@ class ConversationVariablesApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>") @service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
class ConversationVariableDetailApi(Resource): class ConversationVariableDetailApi(Resource):
@service_api_ns.expect(conversation_variable_update_parser) @service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
@service_api_ns.doc("update_conversation_variable") @service_api_ns.doc("update_conversation_variable")
@service_api_ns.doc(description="Update a conversation variable's value") @service_api_ns.doc(description="Update a conversation variable's value")
@service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"}) @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
@ -250,11 +234,11 @@ class ConversationVariableDetailApi(Resource):
conversation_id = str(c_id) conversation_id = str(c_id)
variable_id = str(variable_id) variable_id = str(variable_id)
args = conversation_variable_update_parser.parse_args() payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try: try:
return ConversationService.update_conversation_variable( return ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, args["value"] app_model, conversation_id, variable_id, end_user, payload.value
) )
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@ -1,9 +1,11 @@
import logging import logging
from urllib.parse import quote from urllib.parse import quote
from flask import Response from flask import Response, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
FileAccessDeniedError, FileAccessDeniedError,
@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define parser for file preview API class FilePreviewQuery(BaseModel):
file_preview_parser = reqparse.RequestParser().add_argument( as_attachment: bool = Field(default=False, description="Download as attachment")
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
)
register_schema_model(service_api_ns, FilePreviewQuery)
@service_api_ns.route("/files/<uuid:file_id>/preview") @service_api_ns.route("/files/<uuid:file_id>/preview")
@ -32,7 +35,7 @@ class FilePreviewApi(Resource):
Files can only be accessed if they belong to messages within the requesting app's context. Files can only be accessed if they belong to messages within the requesting app's context.
""" """
@service_api_ns.expect(file_preview_parser) @service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__])
@service_api_ns.doc("preview_file") @service_api_ns.doc("preview_file")
@service_api_ns.doc(description="Preview or download a file uploaded via Service API") @service_api_ns.doc(description="Preview or download a file uploaded via Service API")
@service_api_ns.doc(params={"file_id": "UUID of the file to preview"}) @service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
@ -55,7 +58,7 @@ class FilePreviewApi(Resource):
file_id = str(file_id) file_id = str(file_id)
# Parse query parameters # Parse query parameters
args = file_preview_parser.parse_args() args = FilePreviewQuery.model_validate(request.args.to_dict())
# Validate file ownership and get file objects # Validate file ownership and get file objects
_, upload_file = self._validate_file_ownership(file_id, app_model.id) _, upload_file = self._validate_file_ownership(file_id, app_model.id)
@ -67,7 +70,7 @@ class FilePreviewApi(Resource):
raise FileNotFoundError(f"Failed to load file content: {str(e)}") raise FileNotFoundError(f"Failed to load file content: {str(e)}")
# Build response with appropriate headers # Build response with appropriate headers
response = self._build_file_response(generator, upload_file, args["as_attachment"]) response = self._build_file_response(generator, upload_file, args.as_attachment)
return response return response

View File

@ -1,11 +1,15 @@
import json import json
import logging import logging
from typing import Literal
from uuid import UUID
from flask_restx import Api, Namespace, Resource, fields, reqparse from flask import request
from flask_restx.inputs import int_range from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import build_message_file_model from fields.conversation_fields import build_message_file_model
from fields.message_fields import build_agent_thought_model, build_feedback_model from fields.message_fields import build_agent_thought_model, build_feedback_model
from fields.raws import FilesContainedField from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.errors.message import ( from services.errors.message import (
FirstMessageNotExistsError, FirstMessageNotExistsError,
@ -25,42 +29,26 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define parsers for message APIs class MessageListQuery(BaseModel):
message_list_parser = ( conversation_id: UUID
reqparse.RequestParser() first_id: UUID | None = None
.add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID") limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of messages to return",
)
)
message_feedback_parser = (
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
.add_argument("content", type=str, location="json", help="Feedback content")
)
feedback_list_parser = (
reqparse.RequestParser()
.add_argument("page", type=int, default=1, location="args", help="Page number")
.add_argument(
"limit",
type=int_range(1, 101),
required=False,
default=20,
location="args",
help="Number of feedbacks per page",
)
)
def build_message_model(api_or_ns: Api | Namespace): class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class FeedbackListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
def build_message_model(api_or_ns: Namespace):
"""Build the message model for the API or Namespace.""" """Build the message model for the API or Namespace."""
# First build the nested models # First build the nested models
feedback_model = build_feedback_model(api_or_ns) feedback_model = build_feedback_model(api_or_ns)
@ -90,7 +78,7 @@ def build_message_model(api_or_ns: Api | Namespace):
return api_or_ns.model("Message", message_fields) return api_or_ns.model("Message", message_fields)
def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the message infinite scroll pagination model for the API or Namespace.""" """Build the message infinite scroll pagination model for the API or Namespace."""
# Build the nested message model first # Build the nested message model first
message_model = build_message_model(api_or_ns) message_model = build_message_model(api_or_ns)
@ -105,7 +93,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
@service_api_ns.route("/messages") @service_api_ns.route("/messages")
class MessageListApi(Resource): class MessageListApi(Resource):
@service_api_ns.expect(message_list_parser) @service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
@service_api_ns.doc("list_messages") @service_api_ns.doc("list_messages")
@service_api_ns.doc(description="List messages in a conversation") @service_api_ns.doc(description="List messages in a conversation")
@service_api_ns.doc( @service_api_ns.doc(
@ -126,11 +114,13 @@ class MessageListApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
args = message_list_parser.parse_args() query_args = MessageListQuery.model_validate(request.args.to_dict())
conversation_id = str(query_args.conversation_id)
first_id = str(query_args.first_id) if query_args.first_id else None
try: try:
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] app_model, end_user, conversation_id, first_id, query_args.limit
) )
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -140,7 +130,7 @@ class MessageListApi(Resource):
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks") @service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(Resource): class MessageFeedbackApi(Resource):
@service_api_ns.expect(message_feedback_parser) @service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
@service_api_ns.doc("create_message_feedback") @service_api_ns.doc("create_message_feedback")
@service_api_ns.doc(description="Submit feedback for a message") @service_api_ns.doc(description="Submit feedback for a message")
@service_api_ns.doc(params={"message_id": "Message ID"}) @service_api_ns.doc(params={"message_id": "Message ID"})
@ -159,15 +149,15 @@ class MessageFeedbackApi(Resource):
""" """
message_id = str(message_id) message_id = str(message_id)
args = message_feedback_parser.parse_args() payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
try: try:
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
user=end_user, user=end_user,
rating=args.get("rating"), rating=payload.rating,
content=args.get("content"), content=payload.content,
) )
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -177,7 +167,7 @@ class MessageFeedbackApi(Resource):
@service_api_ns.route("/app/feedbacks") @service_api_ns.route("/app/feedbacks")
class AppGetFeedbacksApi(Resource): class AppGetFeedbacksApi(Resource):
@service_api_ns.expect(feedback_list_parser) @service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__])
@service_api_ns.doc("get_app_feedbacks") @service_api_ns.doc("get_app_feedbacks")
@service_api_ns.doc(description="Get all feedbacks for the application") @service_api_ns.doc(description="Get all feedbacks for the application")
@service_api_ns.doc( @service_api_ns.doc(
@ -192,8 +182,8 @@ class AppGetFeedbacksApi(Resource):
Returns paginated list of all feedback submitted for messages in this app. Returns paginated list of all feedback submitted for messages in this app.
""" """
args = feedback_list_parser.parse_args() query_args = FeedbackListQuery.model_validate(request.args.to_dict())
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"]) feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit)
return {"data": feedbacks} return {"data": feedbacks}

View File

@ -1,12 +1,14 @@
import logging import logging
from typing import Any, Literal
from dateutil.parser import isoparse from dateutil.parser import isoparse
from flask import request from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse from flask_restx import Api, Namespace, Resource, fields
from flask_restx.inputs import int_range from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
CompletionRequestError, CompletionRequestError,
@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define parsers for workflow APIs
workflow_run_parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
)
workflow_log_parser = ( class WorkflowRunPayload(BaseModel):
reqparse.RequestParser() inputs: dict[str, Any]
.add_argument("keyword", type=str, location="args") files: list[dict[str, Any]] | None = None
.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") response_mode: Literal["blocking", "streaming"] | None = None
.add_argument("created_at__before", type=str, location="args")
.add_argument("created_at__after", type=str, location="args")
.add_argument( class WorkflowLogQuery(BaseModel):
"created_by_end_user_session_id", keyword: str | None = None
type=str, status: Literal["succeeded", "failed", "stopped"] | None = None
location="args", created_at__before: str | None = None
required=False, created_at__after: str | None = None
default=None, created_by_end_user_session_id: str | None = None
) created_by_account: str | None = None
.add_argument( page: int = Field(default=1, ge=1, le=99999)
"created_by_account", limit: int = Field(default=20, ge=1, le=100)
type=str,
location="args",
required=False, register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
default=None,
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
workflow_run_fields = { workflow_run_fields = {
"id": fields.String, "id": fields.String,
@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource):
@service_api_ns.route("/workflows/run") @service_api_ns.route("/workflows/run")
class WorkflowRunApi(Resource): class WorkflowRunApi(Resource):
@service_api_ns.expect(workflow_run_parser) @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
@service_api_ns.doc("run_workflow") @service_api_ns.doc("run_workflow")
@service_api_ns.doc(description="Execute a workflow") @service_api_ns.doc(description="Execute a workflow")
@service_api_ns.doc( @service_api_ns.doc(
@ -154,11 +144,12 @@ class WorkflowRunApi(Resource):
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
args = workflow_run_parser.parse_args() payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
if external_trace_id: if external_trace_id:
args["external_trace_id"] = external_trace_id args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming" streaming = payload.response_mode == "streaming"
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
@ -185,7 +176,7 @@ class WorkflowRunApi(Resource):
@service_api_ns.route("/workflows/<string:workflow_id>/run") @service_api_ns.route("/workflows/<string:workflow_id>/run")
class WorkflowRunByIdApi(Resource): class WorkflowRunByIdApi(Resource):
@service_api_ns.expect(workflow_run_parser) @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
@service_api_ns.doc("run_workflow_by_id") @service_api_ns.doc("run_workflow_by_id")
@service_api_ns.doc(description="Execute a specific workflow by ID") @service_api_ns.doc(description="Execute a specific workflow by ID")
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"}) @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource):
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
args = workflow_run_parser.parse_args() payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
# Add workflow_id to args for AppGenerateService # Add workflow_id to args for AppGenerateService
args["workflow_id"] = workflow_id args["workflow_id"] = workflow_id
@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource):
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
if external_trace_id: if external_trace_id:
args["external_trace_id"] = external_trace_id args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming" streaming = payload.response_mode == "streaming"
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource):
@service_api_ns.route("/workflows/logs") @service_api_ns.route("/workflows/logs")
class WorkflowAppLogApi(Resource): class WorkflowAppLogApi(Resource):
@service_api_ns.expect(workflow_log_parser) @service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__])
@service_api_ns.doc("get_workflow_logs") @service_api_ns.doc("get_workflow_logs")
@service_api_ns.doc(description="Get workflow execution logs") @service_api_ns.doc(description="Get workflow execution logs")
@service_api_ns.doc( @service_api_ns.doc(
@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource):
Returns paginated workflow execution logs with filtering options. Returns paginated workflow execution logs with filtering options.
""" """
args = workflow_log_parser.parse_args() args = WorkflowLogQuery.model_validate(request.args.to_dict())
args.status = WorkflowExecutionStatus(args.status) if args.status else None status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before: created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
args.created_at__before = isoparse(args.created_at__before) created_at_after = isoparse(args.created_at__after) if args.created_at__after else None
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()
@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource):
session=session, session=session,
app_model=app_model, app_model=app_model,
keyword=args.keyword, keyword=args.keyword,
status=args.status, status=status,
created_at_before=args.created_at__before, created_at_before=created_at_before,
created_at_after=args.created_at__after, created_at_after=created_at_after,
page=args.page, page=args.page,
limit=args.limit, limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id, created_by_end_user_session_id=args.created_by_end_user_session_id,

View File

@ -1,10 +1,12 @@
from typing import Any, Literal, cast from typing import Any, Literal, cast
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService from services.tag_service import TagService
def _validate_name(name): class DatasetCreatePayload(BaseModel):
if not name or len(name) < 1 or len(name) > 40: name: str = Field(..., min_length=1, max_length=40)
raise ValueError("Name must be between 1 to 40 characters.") description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400)
return name indexing_technique: Literal["high_quality", "economy"] | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
external_knowledge_api_id: str | None = None
provider: str = "vendor"
external_knowledge_id: str | None = None
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
# Define parsers for dataset operations class DatasetUpdatePayload(BaseModel):
dataset_create_parser = ( name: str | None = Field(default=None, min_length=1, max_length=40)
reqparse.RequestParser() description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400)
.add_argument( indexing_technique: Literal["high_quality", "economy"] | None = None
"name", permission: DatasetPermissionEnum | None = None
nullable=False, embedding_model: str | None = None
required=True, embedding_model_provider: str | None = None
help="type is required. Name must be between 1 to 40 characters.", retrieval_model: RetrievalModel | None = None
type=_validate_name, partial_member_list: list[str] | None = None
) external_retrieval_model: dict[str, Any] | None = None
.add_argument( external_knowledge_id: str | None = None
"description", external_knowledge_api_id: str | None = None
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
required=False,
nullable=False,
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
default="_validate_name",
)
.add_argument(
"provider",
type=str,
nullable=True,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
dataset_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
)
tag_create_parser = reqparse.RequestParser().add_argument( class TagNamePayload(BaseModel):
"name", name: str = Field(..., min_length=1, max_length=50)
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
tag_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
)
tag_delete_parser = reqparse.RequestParser().add_argument( class TagCreatePayload(TagNamePayload):
"tag_id", nullable=False, required=True, help="Id of a tag.", type=str pass
)
tag_binding_parser = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)
)
tag_unbinding_parser = ( class TagUpdatePayload(TagNamePayload):
reqparse.RequestParser() tag_id: str
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
class TagDeletePayload(BaseModel):
tag_id: str
class TagBindingPayload(BaseModel):
tag_ids: list[str]
target_id: str
@field_validator("tag_ids")
@classmethod
def validate_tag_ids(cls, value: list[str]) -> list[str]:
if not value:
raise ValueError("Tag IDs is required.")
return value
class TagUnbindingPayload(BaseModel):
tag_id: str
target_id: str
register_schema_models(
service_api_ns,
DatasetCreatePayload,
DatasetUpdatePayload,
TagCreatePayload,
TagUpdatePayload,
TagDeletePayload,
TagBindingPayload,
TagUnbindingPayload,
) )
@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200 return response, 200
@service_api_ns.expect(dataset_create_parser) @service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset") @service_api_ns.doc("create_dataset")
@service_api_ns.doc(description="Create a new dataset") @service_api_ns.doc(description="Create a new dataset")
@service_api_ns.doc( @service_api_ns.doc(
@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id): def post(self, tenant_id):
"""Resource for creating datasets.""" """Resource for creating datasets."""
args = dataset_create_parser.parse_args() payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {})
embedding_model_provider = args.get("embedding_model_provider") embedding_model_provider = payload.embedding_model_provider
embedding_model = args.get("embedding_model") embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model: if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model") retrieval_model = payload.retrieval_model
if ( if (
retrieval_model retrieval_model
and retrieval_model.get("reranking_model") and retrieval_model.reranking_model
and retrieval_model.get("reranking_model").get("reranking_provider_name") and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"), retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.get("reranking_model").get("reranking_model_name"), retrieval_model.reranking_model.reranking_model_name,
) )
try: try:
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
name=args["name"], name=payload.name,
description=args["description"], description=payload.description,
indexing_technique=args["indexing_technique"], indexing_technique=payload.indexing_technique,
account=current_user, account=current_user,
permission=args["permission"], permission=str(payload.permission) if payload.permission else None,
provider=args["provider"], provider=payload.provider,
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=args["external_knowledge_id"], external_knowledge_id=payload.external_knowledge_id,
embedding_model_provider=args["embedding_model_provider"], embedding_model_provider=payload.embedding_model_provider,
embedding_model_name=args["embedding_model"], embedding_model_name=payload.embedding_model,
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"]) retrieval_model=payload.retrieval_model,
if args["retrieval_model"] is not None
else None,
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource):
return data, 200 return data, 200
@service_api_ns.expect(dataset_update_parser) @service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset") @service_api_ns.doc("update_dataset")
@service_api_ns.doc(description="Update an existing dataset") @service_api_ns.doc(description="Update an existing dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
args = dataset_update_parser.parse_args() payload_dict = service_api_ns.payload or {}
data = request.get_json() payload = DatasetUpdatePayload.model_validate(payload_dict)
update_data = payload.model_dump(exclude_unset=True)
if payload.permission is not None:
update_data["permission"] = str(payload.permission)
if payload.retrieval_model is not None:
update_data["retrieval_model"] = payload.retrieval_model.model_dump()
# check embedding model setting # check embedding model setting
embedding_model_provider = data.get("embedding_model_provider") embedding_model_provider = payload.embedding_model_provider
embedding_model = data.get("embedding_model") embedding_model = payload.embedding_model
if data.get("indexing_technique") == "high_quality" or embedding_model_provider: if payload.indexing_technique == "high_quality" or embedding_model_provider:
if embedding_model_provider and embedding_model: if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting( DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model dataset.tenant_id, embedding_model_provider, embedding_model
) )
retrieval_model = data.get("retrieval_model") retrieval_model = payload.retrieval_model
if ( if (
retrieval_model retrieval_model
and retrieval_model.get("reranking_model") and retrieval_model.reranking_model
and retrieval_model.get("reranking_model").get("reranking_provider_name") and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
dataset.tenant_id, dataset.tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"), retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.get("reranking_model").get("reranking_model_name"), retrieval_model.reranking_model.reranking_model_name,
) )
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission( DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list") current_user,
dataset,
str(payload.permission) if payload.permission else None,
payload.partial_member_list,
) )
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource):
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members # clear partial member list when permission is only_me or all_team_members
elif ( elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str) DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource):
return tags, 200 return tags, 200
@service_api_ns.expect(tag_create_parser) @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag") @service_api_ns.doc("create_dataset_tag")
@service_api_ns.doc(description="Add a knowledge type tag") @service_api_ns.doc(description="Add a knowledge type tag")
@service_api_ns.doc( @service_api_ns.doc(
@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_create_parser.parse_args() payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
args["type"] = "knowledge" tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200 return response, 200
@service_api_ns.expect(tag_update_parser) @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_tag") @service_api_ns.doc("update_dataset_tag")
@service_api_ns.doc(description="Update a knowledge type tag") @service_api_ns.doc(description="Update a knowledge type tag")
@service_api_ns.doc( @service_api_ns.doc(
@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_update_parser.parse_args() payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
args["type"] = "knowledge" params = {"name": payload.name, "type": "knowledge"}
tag_id = args["tag_id"] tag_id = payload.tag_id
tag = TagService.update_tags(args, tag_id) tag = TagService.update_tags(params, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource):
return response, 200 return response, 200
@service_api_ns.expect(tag_delete_parser) @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@service_api_ns.doc("delete_dataset_tag") @service_api_ns.doc("delete_dataset_tag")
@service_api_ns.doc(description="Delete a knowledge type tag") @service_api_ns.doc(description="Delete a knowledge type tag")
@service_api_ns.doc( @service_api_ns.doc(
@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource):
@edit_permission_required @edit_permission_required
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
"""Delete a knowledge type tag.""" """Delete a knowledge type tag."""
args = tag_delete_parser.parse_args() payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(args["tag_id"]) TagService.delete_tag(payload.tag_id)
return 204 return 204
@service_api_ns.route("/datasets/tags/binding") @service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource): class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.expect(tag_binding_parser) @service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
@service_api_ns.doc("bind_dataset_tags") @service_api_ns.doc("bind_dataset_tags")
@service_api_ns.doc(description="Bind tags to a dataset") @service_api_ns.doc(description="Bind tags to a dataset")
@service_api_ns.doc( @service_api_ns.doc(
@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_binding_parser.parse_args() payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
args["type"] = "knowledge" TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
TagService.save_tag_binding(args)
return 204 return 204
@service_api_ns.route("/datasets/tags/unbinding") @service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource): class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(tag_unbinding_parser) @service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag") @service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset") @service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc( @service_api_ns.doc(
@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
args = tag_unbinding_parser.parse_args() payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
args["type"] = "knowledge" TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
TagService.delete_tag_binding(args)
return 204 return 204

View File

@ -3,8 +3,8 @@ from typing import Self
from uuid import UUID from uuid import UUID
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel, model_validator from pydantic import BaseModel, Field, model_validator
from sqlalchemy import desc, select from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService from services.file_service import FileService
# Define parsers for document operations
document_text_create_parser = ( class DocumentTextCreatePayload(BaseModel):
reqparse.RequestParser() name: str
.add_argument("name", type=str, required=True, nullable=False, location="json") text: str
.add_argument("text", type=str, required=True, nullable=False, location="json") process_rule: ProcessRule | None = None
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") original_document_id: str | None = None
.add_argument("original_document_id", type=str, required=False, location="json") doc_form: str = Field(default="text_model")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") doc_language: str = Field(default="English")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") indexing_technique: str | None = None
.add_argument( retrieval_model: RetrievalModel | None = None
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" embedding_model: str | None = None
) embedding_model_provider: str | None = None
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel):
return self return self
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]: for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
class DocumentAddByTextApi(DatasetApiResource): class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents.""" """Resource for documents."""
@service_api_ns.expect(document_text_create_parser) @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text") @service_api_ns.doc("create_document_by_text")
@service_api_ns.doc(description="Create a new document by providing text content") @service_api_ns.doc(description="Create a new document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create document by text.""" """Create document by text."""
args = document_text_create_parser.parse_args() payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource):
if not dataset.indexing_technique and not args["indexing_technique"]: if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
text = args.get("text") embedding_model_provider = payload.embedding_model_provider
name = args.get("name") embedding_model = payload.embedding_model
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
if embedding_model_provider and embedding_model: if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model") retrieval_model = payload.retrieval_model
if ( if (
retrieval_model retrieval_model
and retrieval_model.get("reranking_model") and retrieval_model.reranking_model
and retrieval_model.get("reranking_model").get("reranking_provider_name") and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"), retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.get("reranking_model").get("reranking_model_name"), retrieval_model.reranking_model.reranking_model_name,
) )
if not current_user: if not current_user:
raise ValueError("current_user is required") raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text( upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
) )
data_source = { data_source = {
"type": "upload_file", "type": "upload_file",
@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents.""" """Resource for update documents."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True) @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text") @service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content") @service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -189,22 +183,23 @@ class DocumentUpdateByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text.""" """Update document by text."""
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True) payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
args = payload.model_dump(exclude_none=True)
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
retrieval_model = args.get("retrieval_model") retrieval_model = payload.retrieval_model
if ( if (
retrieval_model retrieval_model
and retrieval_model.get("reranking_model") and retrieval_model.reranking_model
and retrieval_model.get("reranking_model").get("reranking_provider_name") and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"), retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.get("reranking_model").get("reranking_model_name"), retrieval_model.reranking_model.reranking_model_name,
) )
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update

View File

@ -1,9 +1,11 @@
from typing import Literal from typing import Literal
from flask_login import current_user from flask_login import current_user
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import (
) )
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
# Define parsers for metadata APIs
metadata_create_parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type")
.add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name")
)
metadata_update_parser = reqparse.RequestParser().add_argument( class MetadataUpdatePayload(BaseModel):
"name", type=str, required=True, nullable=False, location="json", help="New metadata name" name: str
)
document_metadata_parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" register_schema_model(service_api_ns, MetadataUpdatePayload)
) register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata") @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_create_parser) @service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
@service_api_ns.doc("create_dataset_metadata") @service_api_ns.doc("create_dataset_metadata")
@service_api_ns.doc(description="Create metadata for a dataset") @service_api_ns.doc(description="Create metadata for a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset.""" """Create metadata for a dataset."""
args = metadata_create_parser.parse_args() metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>") @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataServiceApi(DatasetApiResource): class DatasetMetadataServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_update_parser) @service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_metadata") @service_api_ns.doc("update_dataset_metadata")
@service_api_ns.doc(description="Update metadata name") @service_api_ns.doc(description="Update metadata name")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id): def patch(self, tenant_id, dataset_id, metadata_id):
"""Update metadata name.""" """Update metadata name."""
args = metadata_update_parser.parse_args() payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"]) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return marshal(metadata, dataset_metadata_fields), 200 return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata") @service_api_ns.doc("delete_dataset_metadata")
@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata") @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditServiceApi(DatasetApiResource): class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.expect(document_metadata_parser) @service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
@service_api_ns.doc("update_documents_metadata") @service_api_ns.doc("update_documents_metadata")
@service_api_ns.doc(description="Update metadata for multiple documents") @service_api_ns.doc(description="Update metadata for multiple documents")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args() metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -4,12 +4,12 @@ from collections.abc import Generator
from typing import Any from typing import Any
from flask import request from flask import request
from flask_restx import reqparse from pydantic import BaseModel
from flask_restx.reqparse import ParseResult, RequestParser
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource
@ -22,11 +22,25 @@ from models.dataset import Pipeline
from models.engine import db from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity from services.rag_pipeline.entity.pipeline_service_api_entities import (
DatasourceNodeRunApiEntity,
PipelineRunApiEntity,
)
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
is_published: bool
register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity)
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins") @service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource): class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins.""" """Resource for datasource plugins."""
@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str): def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins.""" """Resource for getting datasource plugins."""
# Get query parameter to determine published or draft payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
parser: RequestParser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
.add_argument("is_published", type=bool, required=True, location="json")
)
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService() rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
{
**payload.model_dump(exclude_none=True),
"pipeline_id": str(pipeline.id),
"node_id": node_id,
}
)
return helper.compact_generate_response( return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream( PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node( rag_pipeline_service.run_datasource_workflow_node(
@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token", 401: "Unauthorized - invalid API token",
} }
) )
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str): def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline.""" """Resource for running a rag pipeline."""
parser: RequestParser = ( payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_published", type=bool, required=True, default=True, location="json")
.add_argument(
"response_mode",
type=str,
required=True,
choices=["streaming", "blocking"],
default="blocking",
location="json",
)
)
args: ParseResult = parser.parse_args()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource):
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate( response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline, pipeline=pipeline,
user=current_user, user=current_user,
args=args, args=payload.model_dump(),
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
streaming=args.get("response_mode") == "streaming", streaming=payload.response_mode == "streaming",
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)

View File

@ -1,8 +1,12 @@
from typing import Any
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from configs import dify_config from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
# Define parsers for segment operations
segment_create_parser = reqparse.RequestParser().add_argument(
"segments", type=list, required=False, nullable=True, location="json"
)
segment_list_parser = ( class SegmentCreatePayload(BaseModel):
reqparse.RequestParser() segments: list[dict[str, Any]] | None = None
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("keyword", type=str, default=None, location="args")
)
segment_update_parser = reqparse.RequestParser().add_argument(
"segment", type=dict, required=False, nullable=True, location="json"
)
child_chunk_create_parser = reqparse.RequestParser().add_argument( class SegmentListQuery(BaseModel):
"content", type=str, required=True, nullable=False, location="json" status: list[str] = Field(default_factory=list)
) keyword: str | None = None
child_chunk_list_parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
child_chunk_update_parser = reqparse.RequestParser().add_argument( class SegmentUpdatePayload(BaseModel):
"content", type=str, required=True, nullable=False, location="json" segment: SegmentUpdateArgs
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class ChildChunkUpdatePayload(BaseModel):
content: str
register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,
ChildChunkUpdatePayload,
) )
@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument(
class SegmentApi(DatasetApiResource): class SegmentApi(DatasetApiResource):
"""Resource for segments.""" """Resource for segments."""
@service_api_ns.expect(segment_create_parser) @service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments") @service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document") @service_api_ns.doc(description="Create segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# validate args # validate args
args = segment_create_parser.parse_args() payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
if args["segments"] is not None: if payload.segments is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit: if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.") raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]: for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document) SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset) segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else: else:
return {"error": "Segments is required"}, 400 return {"error": "Segments is required"}, 400
@service_api_ns.expect(segment_list_parser) @service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
@service_api_ns.doc("list_segments") @service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document") @service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
args = segment_list_parser.parse_args() args = SegmentListQuery.model_validate(
{
"status": request.args.getlist("status"),
"keyword": request.args.get("keyword"),
}
)
segments, total = SegmentService.get_segments( segments, total = SegmentService.get_segments(
document_id=document_id, document_id=document_id,
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
status_list=args["status"], status_list=args.status,
keyword=args["keyword"], keyword=args.keyword,
page=page, page=page,
limit=limit, limit=limit,
) )
@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource):
SegmentService.delete_segment(segment, document, dataset) SegmentService.delete_segment(segment, document, dataset)
return 204 return 204
@service_api_ns.expect(segment_update_parser) @service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment") @service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment") @service_api_ns.doc(description="Update a specific segment")
@service_api_ns.doc( @service_api_ns.doc(
@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# validate args payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment( updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment") @service_api_ns.doc("get_segment")
@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource):
class ChildChunkApi(DatasetApiResource): class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks.""" """Resource for child chunks."""
@service_api_ns.expect(child_chunk_create_parser) @service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk") @service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment") @service_api_ns.doc(description="Create a new child chunk for a segment")
@service_api_ns.doc( @service_api_ns.doc(
@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource):
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# validate args # validate args
args = child_chunk_create_parser.parse_args() payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {})
try: try:
child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset) child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@service_api_ns.expect(child_chunk_list_parser) @service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
@service_api_ns.doc("list_child_chunks") @service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment") @service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc( @service_api_ns.doc(
@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource):
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
args = child_chunk_list_parser.parse_args() args = ChildChunkListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
page = args["page"] page = args.page
limit = min(args["limit"], 100) limit = min(args.limit, 100)
keyword = args["keyword"] keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
return 204 return 204
@service_api_ns.expect(child_chunk_update_parser) @service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk") @service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk") @service_api_ns.doc(description="Update a specific child chunk")
@service_api_ns.doc( @service_api_ns.doc(
@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# validate args # validate args
args = child_chunk_update_parser.parse_args() payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
try: try:
child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset) child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))

View File

@ -1,3 +1,4 @@
import logging
import time import time
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
@ -28,6 +29,8 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
T = TypeVar("T") T = TypeVar("T")
logger = logging.getLogger(__name__)
class WhereisUserArg(StrEnum): class WhereisUserArg(StrEnum):
""" """
@ -238,8 +241,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
# Basic check: UUIDs are 36 chars with hyphens # Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count("-") == 4: if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id dataset_id = str_id
except: except Exception:
pass logger.exception("Failed to parse dataset_id from class method args")
elif len(args) > 0: elif len(args) > 0:
# Not a class method, check if args[0] looks like a UUID # Not a class method, check if args[0] looks like a UUID
potential_id = args[0] potential_id = args[0]
@ -247,8 +250,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
str_id = str(potential_id) str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4: if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id dataset_id = str_id
except: except Exception:
pass logger.exception("Failed to parse dataset_id from positional args")
# Validate dataset if dataset_id is provided # Validate dataset if dataset_id is provided
if dataset_id: if dataset_id:
@ -316,18 +319,16 @@ def validate_and_get_api_token(scope: str | None = None):
ApiToken.type == scope, ApiToken.type == scope,
) )
.values(last_used_at=current_time) .values(last_used_at=current_time)
.returning(ApiToken)
) )
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
result = session.execute(update_stmt) result = session.execute(update_stmt)
api_token = result.scalar_one_or_none() api_token = session.scalar(stmt)
if hasattr(result, "rowcount") and result.rowcount > 0:
session.commit()
if not api_token: if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) raise Unauthorized("Access token is invalid")
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()
return api_token return api_token

View File

@ -1,4 +1,5 @@
import json import json
import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any from typing import Any
@ -23,6 +24,8 @@ from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from models.model import Message from models.model import Message
logger = logging.getLogger(__name__)
class CotAgentRunner(BaseAgentRunner, ABC): class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True _is_first_iteration = True
@ -400,8 +403,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
action_input=json.loads(message.tool_calls[0].function.arguments), action_input=json.loads(message.tool_calls[0].function.arguments),
) )
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except: except Exception:
pass logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage): elif isinstance(message, ToolPromptMessage):
if current_scratchpad: if current_scratchpad:
assert isinstance(message.content, str) assert isinstance(message.content, str)

View File

@ -2,6 +2,7 @@ from collections.abc import Sequence
from enum import StrEnum, auto from enum import StrEnum, auto
from typing import Any, Literal from typing import Any, Literal
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from core.file import FileTransferMethod, FileType, FileUploadConfig from core.file import FileTransferMethod, FileType, FileUploadConfig
@ -98,6 +99,7 @@ class VariableEntityType(StrEnum):
FILE = "file" FILE = "file"
FILE_LIST = "file-list" FILE_LIST = "file-list"
CHECKBOX = "checkbox" CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel): class VariableEntity(BaseModel):
@ -118,6 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
@field_validator("description", mode="before") @field_validator("description", mode="before")
@classmethod @classmethod
@ -129,6 +132,17 @@ class VariableEntity(BaseModel):
def convert_none_options(cls, v: Any) -> Sequence[str]: def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or [] return v or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
class RagPipelineVariableEntity(VariableEntity): class RagPipelineVariableEntity(VariableEntity):
""" """

View File

@ -35,6 +35,7 @@ from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from models import Workflow from models import Workflow
from models.enums import UserFrom from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation from models.model import App, Conversation, Message, MessageAnnotation
@ -80,6 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._workflow_execution_repository = workflow_execution_repository self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository
@trace_span(WorkflowAppRunnerHandler)
def run(self): def run(self):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config) app_config = cast(AdvancedChatAppConfig, app_config)

View File

@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@ -73,7 +72,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
from models.workflow import Workflow, WorkflowNodeExecutionModel from models.workflow import Workflow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
with self._database_session() as session: with self._database_session() as session:
# Save message # Save message
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) self._save_message(session=session, graph_runtime_state=resolved_state)
yield workflow_finish_resp yield workflow_finish_resp
elif event.stopped_by in ( elif event.stopped_by in (
@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# When hitting input-moderation or annotation-reply, the workflow will not start # When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session: with self._database_session() as session:
# Save message # Save message
self._save_message(session=session, trace_manager=trace_manager) self._save_message(session=session)
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent, event: QueueAdvancedChatMessageEndEvent,
*, *,
graph_runtime_state: GraphRuntimeState | None = None, graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs, **kwargs,
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events.""" """Handle advanced chat message end events."""
@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message # Save message
with self._database_session() as session: with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
@ -770,15 +768,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
tts_publisher.publish(None) tts_publisher.publish(None)
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() logger.debug("Conversation name generation running as daemon thread")
def _save_message( def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
message = self._get_message(session=session) message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer # If there are assistant files, remove markdown image links from answer
@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
metadata = self._task_state.metadata.model_dump() metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata)) message.message_metadata = json.dumps(jsonable_encoder(metadata))
# Extract model provider and model_id from workflow node executions for tracing
if message.workflow_run_id:
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
if model_info:
message.model_provider = model_info.get("provider")
message.model_id = model_info.get("model")
message_files = [ message_files = [
MessageFile( MessageFile(
message_id=message.id, message_id=message.id,
@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
] ]
session.add_all(message_files) session.add_all(message_files)
# Trigger MESSAGE_TRACE for tracing integrations
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
"""
Extract model provider and model_id from workflow node executions.
Returns dict with 'provider' and 'model' keys, or None if not found.
"""
try:
# Query workflow node executions for LLM or Agent nodes
stmt = (
select(WorkflowNodeExecutionModel)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.limit(1)
)
node_execution = session.scalar(stmt)
if not node_execution:
return None
# Try to extract from execution_metadata for agent nodes
if node_execution.execution_metadata:
try:
metadata = json.loads(node_execution.execution_metadata)
agent_log = metadata.get("agent_log", [])
# Look for the first agent thought with provider info
for log_entry in agent_log:
entry_metadata = log_entry.get("metadata", {})
provider_str = entry_metadata.get("provider")
if provider_str:
# Parse format like "langgenius/deepseek/deepseek"
parts = provider_str.split("/")
if len(parts) >= 3:
return {"provider": parts[1], "model": parts[2]}
elif len(parts) == 2:
return {"provider": parts[0], "model": parts[1]}
except (json.JSONDecodeError, KeyError, AttributeError) as e:
logger.debug("Failed to parse execution_metadata: %s", e)
# Try to extract from process_data for llm nodes
if node_execution.process_data:
try:
process_data = json.loads(node_execution.process_data)
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider and model:
return {"provider": provider, "model": model}
except (json.JSONDecodeError, KeyError) as e:
logger.debug("Failed to parse process_data: %s", e)
return None
except Exception as e:
logger.warning("Failed to extract model info from workflow: %s", e)
return None
def _seed_graph_runtime_state_from_queue_manager(self) -> None: def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present.""" """Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@ -99,6 +99,15 @@ class BaseAppGenerator:
if value is None: if value is None:
return None return None
# Treat empty placeholders for optional file inputs as unset
if (
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
and not variable_entity.required
):
# Treat empty string (frontend default) or empty list as unset
if not value and isinstance(value, (str, list)):
return None
if variable_entity.type in { if variable_entity.type in {
VariableEntityType.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT, VariableEntityType.SELECT,

View File

@ -83,6 +83,7 @@ class AppRunner:
context: str | None = None, context: str | None = None,
memory: TokenBufferMemory | None = None, memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
) -> tuple[list[PromptMessage], list[str] | None]: ) -> tuple[list[PromptMessage], list[str] | None]:
""" """
Organize prompt messages Organize prompt messages
@ -111,6 +112,7 @@ class AppRunner:
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
image_detail_config=image_detail_config, image_detail_config=image_detail_config,
context_files=context_files,
) )
else: else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

View File

@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
) )
from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.message_entities import ImagePromptMessageContent
@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner):
# get context from datasets # get context from datasets
context = None context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids: if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler( hit_callback = DatasetIndexToolCallbackHandler(
queue_manager, queue_manager,
@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner):
) )
dataset_retrieval = DatasetRetrieval(application_generate_entity) dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve( context, retrieved_files = dataset_retrieval.retrieve(
app_id=app_record.id, app_id=app_record.id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner):
memory=memory, memory=memory,
message_id=message.id, message_id=message.id,
inputs=inputs, inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
) )
context_files = retrieved_files or []
# reorganize all inputs and template to prompt messages # reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner):
context=context, context=context,
memory=memory, memory=memory,
image_detail_config=image_detail_config, image_detail_config=image_detail_config,
context_files=context_files,
) )
# check hosting moderation # check hosting moderation

View File

@ -1,3 +1,4 @@
import logging
import time import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
@ -55,6 +56,7 @@ from models import Account, EndUser
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str) NodeExecutionId = NewType("NodeExecutionId", str)
logger = logging.getLogger(__name__)
@dataclass(slots=True) @dataclass(slots=True)
@ -289,26 +291,30 @@ class WorkflowResponseConverter:
), ),
) )
if event.node_type == NodeType.TOOL: try:
response.data.extras["icon"] = ToolManager.get_tool_icon( if event.node_type == NodeType.TOOL:
tenant_id=self._application_generate_entity.app_config.tenant_id, response.data.extras["icon"] = ToolManager.get_tool_icon(
provider_type=ToolProviderType(event.provider_type), tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_id=event.provider_id, provider_type=ToolProviderType(event.provider_type),
) provider_id=event.provider_id,
elif event.node_type == NodeType.DATASOURCE: )
manager = PluginDatasourceManager() elif event.node_type == NodeType.DATASOURCE:
provider_entity = manager.fetch_datasource_provider( manager = PluginDatasourceManager()
self._application_generate_entity.app_config.tenant_id, provider_entity = manager.fetch_datasource_provider(
event.provider_id, self._application_generate_entity.app_config.tenant_id,
) event.provider_id,
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url( )
self._application_generate_entity.app_config.tenant_id response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
) self._application_generate_entity.app_config.tenant_id
elif event.node_type == NodeType.TRIGGER_PLUGIN: )
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon( elif event.node_type == NodeType.TRIGGER_PLUGIN:
self._application_generate_entity.app_config.tenant_id, response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
event.provider_id, self._application_generate_entity.app_config.tenant_id,
) event.provider_id,
)
except Exception:
# metadata fetch may fail, for example, the plugin daemon is down or plugin is uninstalled.
logger.warning("failed to fetch icon for %s", event.provider_id)
return response return response

View File

@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity, CompletionAppGenerateEntity,
) )
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import File
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner):
# get context from datasets # get context from datasets
context = None context = None
context_files: list[File] = []
if app_config.dataset and app_config.dataset.dataset_ids: if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler( hit_callback = DatasetIndexToolCallbackHandler(
queue_manager, queue_manager,
@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner):
query = inputs.get(dataset_config.retrieve_config.query_variable, "") query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval(application_generate_entity) dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve( context, retrieved_files = dataset_retrieval.retrieve(
app_id=app_record.id, app_id=app_record.id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner):
hit_callback=hit_callback, hit_callback=hit_callback,
message_id=message.id, message_id=message.id,
inputs=inputs, inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
),
) )
context_files = retrieved_files or []
# reorganize all inputs and template to prompt messages # reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner):
query=query, query=query,
context=context, context=context,
image_detail_config=image_detail_config, image_detail_config=image_detail_config,
context_files=context_files,
) )
# check hosting moderation # check hosting moderation

View File

@ -156,79 +156,86 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation" query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query conversation_name = (query[:20] + "") if len(query) > 20 else query
if not conversation: try:
conversation = Conversation( if not conversation:
conversation = Conversation(
app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name=conversation_name,
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
)
db.session.add(conversation)
db.session.flush()
db.session.refresh(conversation)
else:
conversation.updated_at = naive_utc_now()
message = Message(
app_id=app_config.app_id, app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider, model_provider=model_provider,
model_id=model_id, model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value, conversation_id=conversation.id,
name=conversation_name,
inputs=application_generate_entity.inputs, inputs=application_generate_entity.inputs,
introduction=introduction, query=application_generate_entity.query,
system_instruction="", message="",
system_instruction_tokens=0, message_tokens=0,
status="normal", message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from=application_generate_entity.invoke_from.value, invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source, from_source=from_source,
from_end_user_id=end_user_id, from_end_user_id=end_user_id,
from_account_id=account_id, from_account_id=account_id,
app_mode=app_config.app_mode,
) )
db.session.add(conversation) db.session.add(message)
db.session.flush()
db.session.refresh(message)
message_files = []
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
message_files.append(message_file)
if message_files:
db.session.add_all(message_files)
db.session.commit() db.session.commit()
db.session.refresh(conversation) return conversation, message
else: except Exception:
conversation.updated_at = naive_utc_now() db.session.rollback()
db.session.commit() raise
message = Message(
app_id=app_config.app_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)
db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)
db.session.commit()
return conversation, message
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
""" """

View File

@ -18,6 +18,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import Workflow from models.workflow import Workflow
@ -56,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._workflow_execution_repository = workflow_execution_repository self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository
@trace_span(WorkflowAppRunnerHandler)
def run(self): def run(self):
""" """
Run application Run application

View File

@ -40,9 +40,6 @@ class EasyUITaskState(TaskState):
""" """
llm_result: LLMResult llm_result: LLMResult
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False
class WorkflowTaskState(TaskState): class WorkflowTaskState(TaskState):

View File

@ -332,12 +332,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if not self._task_state.llm_result.prompt_messages: if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# Track streaming response times
if self._task_state.first_token_time is None:
self._task_state.first_token_time = time.perf_counter()
self._task_state.is_streaming_response = True
self._task_state.last_token_time = time.perf_counter()
# handle output moderation chunk # handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer: if should_direct_answer:
@ -366,7 +360,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if publisher: if publisher:
publisher.publish(None) publisher.publish(None)
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() logger.debug("Conversation name generation running as daemon thread")
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None): def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
""" """
@ -404,18 +398,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.total_price = usage.total_price message.total_price = usage.total_price
message.currency = usage.currency message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency self._task_state.llm_result.usage.latency = message.provider_response_latency
# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)
# Update metadata with the complete usage info
self._task_state.metadata.usage = usage
message.message_metadata = self._task_state.metadata.model_dump_json() message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager: if trace_manager:

Some files were not shown because too many files have changed in this diff Show More