mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/hitl-frontend
This commit is contained in:
commit
ccef15aafa
|
|
@ -9,6 +9,14 @@
|
|||
# Backend (default owner, more specific rules below will override)
|
||||
api/ @QuantumGhost
|
||||
|
||||
# Backend - MCP
|
||||
api/core/mcp/ @Nov1c444
|
||||
api/core/entities/mcp_provider.py @Nov1c444
|
||||
api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||
api/controllers/mcp/ @Nov1c444
|
||||
api/controllers/console/app/mcp_server.py @Nov1c444
|
||||
api/tests/**/*mcp* @Nov1c444
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
name: "✨ Refactor"
|
||||
description: Refactor existing code for improved readability and maintainability.
|
||||
title: "[Chore/Refactor] "
|
||||
labels:
|
||||
- refactor
|
||||
name: "✨ Refactor or Chore"
|
||||
description: Refactor existing code or perform maintenance chores to improve readability and reliability.
|
||||
title: "[Refactor/Chore] "
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
|
|
@ -11,7 +9,7 @@ body:
|
|||
options:
|
||||
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
|
||||
required: true
|
||||
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||
- label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||
required: true
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
|
|
@ -25,14 +23,14 @@ body:
|
|||
id: description
|
||||
attributes:
|
||||
label: Description
|
||||
placeholder: "Describe the refactor you are proposing."
|
||||
placeholder: "Describe the refactor or chore you are proposing."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation
|
||||
placeholder: "Explain why this refactor is necessary."
|
||||
placeholder: "Explain why this refactor or chore is necessary."
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
|
|
|
|||
|
|
@ -1,13 +0,0 @@
|
|||
name: "👾 Tracker"
|
||||
description: For inner usages, please do not use this template.
|
||||
title: "[Tracker] "
|
||||
labels:
|
||||
- tracker
|
||||
body:
|
||||
- type: textarea
|
||||
id: content
|
||||
attributes:
|
||||
label: Blockers
|
||||
placeholder: "- [ ] ..."
|
||||
validations:
|
||||
required: true
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
name: Semantic Pull Request
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- edited
|
||||
- reopened
|
||||
- synchronize
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Validate PR title
|
||||
permissions:
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check title
|
||||
uses: amannn/action-semantic-pull-request@v6.1.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
@ -106,7 +106,7 @@ jobs:
|
|||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
run: pnpm run type-check:tsgo
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ The codebase is split into:
|
|||
|
||||
```bash
|
||||
cd web
|
||||
pnpm lint
|
||||
pnpm lint:fix
|
||||
pnpm type-check:tsgo
|
||||
pnpm test
|
||||
```
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ pnpm test
|
|||
## Language Style
|
||||
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
|
||||
|
||||
## General Practices
|
||||
|
||||
|
|
|
|||
13
README.md
13
README.md
|
|
@ -139,6 +139,19 @@ Star Dify on GitHub and be instantly notified of new releases.
|
|||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
#### Customizing Suggested Questions
|
||||
|
||||
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
|
||||
|
||||
```bash
|
||||
# In your .env file
|
||||
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS=512
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
|
||||
```
|
||||
|
||||
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
||||
|
|
|
|||
|
|
@ -633,8 +633,30 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
|||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Suggested Questions After Answer Configuration
|
||||
# These environment variables allow customization of the suggested questions feature
|
||||
#
|
||||
# Custom prompt for generating suggested questions (optional)
|
||||
# If not set, uses the default prompt that generates 3 questions under 20 characters each
|
||||
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
|
||||
# SUGGESTED_QUESTIONS_PROMPT=
|
||||
|
||||
# Maximum number of tokens for suggested questions generation (default: 256)
|
||||
# Adjust this value for longer questions or more questions
|
||||
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
|
||||
|
||||
# Temperature for suggested questions generation (default: 0.0)
|
||||
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
|
||||
# SUGGESTED_QUESTIONS_TEMPERATURE=0
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
# Multimodal knowledgebase limit
|
||||
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
|
||||
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
|
||||
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
|
||||
IMAGE_FILE_BATCH_LIMIT=10
|
||||
|
|
|
|||
|
|
@ -36,17 +36,20 @@ select = [
|
|||
"UP", # pyupgrade rules
|
||||
"W191", # tab-indentation
|
||||
"W605", # invalid-escape-sequence
|
||||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
"UP042", # use StrEnum,
|
||||
"S110", # disallow the try-except-pass pattern.
|
||||
|
||||
# security related linting rules
|
||||
# RCE proctection (sort of)
|
||||
"S102", # exec-builtin, disallow use of `exec`
|
||||
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
|
||||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage
|
||||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
"UP042", # use StrEnum
|
||||
"S311", # suspicious-non-cryptographic-random-usage,
|
||||
|
||||
]
|
||||
|
||||
ignore = [
|
||||
|
|
@ -91,18 +94,16 @@ ignore = [
|
|||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"core/model_runtime/callbacks/base_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"core/model_runtime/callbacks/base_callback.py" = ["T201"]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from opentelemetry.trace import get_current_span
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from dify_app import DifyApp
|
||||
|
|
@ -26,8 +28,25 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||
# add an unique identifier to each request
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
|
||||
@dify_app.after_request
|
||||
def add_trace_id_header(response):
|
||||
try:
|
||||
span = get_current_span()
|
||||
ctx = span.get_span_context() if span else None
|
||||
if ctx and ctx.is_valid:
|
||||
trace_id_hex = format(ctx.trace_id, "032x")
|
||||
# Avoid duplicates if some middleware added it
|
||||
if "X-Trace-Id" not in response.headers:
|
||||
response.headers["X-Trace-Id"] = trace_id_hex
|
||||
except Exception:
|
||||
# Never break the response due to tracing header injection
|
||||
logger.warning("Failed to add trace ID to response header", exc_info=True)
|
||||
return response
|
||||
|
||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||
_ = before_request
|
||||
_ = add_trace_id_header
|
||||
|
||||
return dify_app
|
||||
|
||||
|
|
|
|||
|
|
@ -1139,6 +1139,7 @@ def remove_orphaned_files_on_storage(force: bool):
|
|||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
all_files_on_storage = []
|
||||
for storage_path in storage_paths:
|
||||
|
|
|
|||
|
|
@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings):
|
|||
default=10,
|
||||
)
|
||||
|
||||
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
|
||||
description="Maximum number of files allowed in a image batch upload operation",
|
||||
default=10,
|
||||
)
|
||||
|
||||
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
|
||||
description="Maximum number of files allowed in a single chunk attachment",
|
||||
default=10,
|
||||
)
|
||||
|
||||
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="Maximum allowed image file size for attachments in megabytes",
|
||||
default=2,
|
||||
)
|
||||
|
||||
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
|
||||
description="Timeout for downloading image attachments in seconds",
|
||||
default=60,
|
||||
)
|
||||
|
||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||
description=(
|
||||
"Comma-separated list of file extensions that are blocked from upload. "
|
||||
|
|
@ -553,7 +573,10 @@ class LoggingConfig(BaseSettings):
|
|||
|
||||
LOG_FORMAT: str = Field(
|
||||
description="Format string for log messages",
|
||||
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
|
||||
default=(
|
||||
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
|
||||
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
|
||||
),
|
||||
)
|
||||
|
||||
LOG_DATEFORMAT: str | None = Field(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
|
||||
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a single BaseModel with a namespace for Swagger documentation."""
|
||||
|
||||
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
|
||||
"""Register multiple BaseModels with a namespace."""
|
||||
|
||||
for model in models:
|
||||
register_schema_model(namespace, model)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||
"register_schema_model",
|
||||
"register_schema_models",
|
||||
]
|
||||
|
|
@ -3,7 +3,8 @@ from functools import wraps
|
|||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
|
@ -18,6 +19,30 @@ from extensions.ext_database import db
|
|||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InsertExploreAppPayload(BaseModel):
|
||||
app_id: str = Field(...)
|
||||
desc: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreAppPayload.__name__,
|
||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
|
|
@ -40,59 +65,34 @@ def admin_required(view: Callable[P, R]):
|
|||
class InsertExploreAppListApi(Resource):
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InsertExploreAppRequest",
|
||||
{
|
||||
"app_id": fields.String(required=True, description="Application ID"),
|
||||
"desc": fields.String(description="App description"),
|
||||
"copyright": fields.String(description="Copyright information"),
|
||||
"privacy_policy": fields.String(description="Privacy policy"),
|
||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
||||
"language": fields.String(required=True, description="Language code"),
|
||||
"category": fields.String(required=True, description="App category"),
|
||||
"position": fields.Integer(required=True, description="Display position"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("desc", type=str, location="json")
|
||||
.add_argument("copyright", type=str, location="json")
|
||||
.add_argument("privacy_policy", type=str, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, location="json")
|
||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
|
||||
if not app:
|
||||
raise NotFound(f"App '{args['app_id']}' is not found")
|
||||
raise NotFound(f"App '{payload.app_id}' is not found")
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = args["desc"] or ""
|
||||
copy_right = args["copyright"] or ""
|
||||
privacy_policy = args["privacy_policy"] or ""
|
||||
custom_disclaimer = args["custom_disclaimer"] or ""
|
||||
desc = payload.desc or ""
|
||||
copy_right = payload.copyright or ""
|
||||
privacy_policy = payload.privacy_policy or ""
|
||||
custom_disclaimer = payload.custom_disclaimer or ""
|
||||
else:
|
||||
desc = site.description or args["desc"] or ""
|
||||
copy_right = site.copyright or args["copyright"] or ""
|
||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||
desc = site.description or payload.desc or ""
|
||||
copy_right = site.copyright or payload.copyright or ""
|
||||
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
|
|
@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource):
|
|||
copyright=copy_right,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
language=args["language"],
|
||||
category=args["category"],
|
||||
position=args["position"],
|
||||
language=payload.language,
|
||||
category=payload.category,
|
||||
position=payload.position,
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
|
|
@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource):
|
|||
recommended_app.copyright = copy_right
|
||||
recommended_app.privacy_policy = privacy_policy
|
||||
recommended_app.custom_disclaimer = custom_disclaimer
|
||||
recommended_app.language = args["language"]
|
||||
recommended_app.category = args["category"]
|
||||
recommended_app.position = args["position"]
|
||||
recommended_app.language = payload.language
|
||||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
|
||||
app.is_public = True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -8,10 +10,21 @@ from libs.login import login_required
|
|||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AgentLogQuery(BaseModel):
|
||||
message_id: str = Field(..., description="Message UUID")
|
||||
conversation_id: str = Field(..., description="Conversation UUID")
|
||||
|
||||
@field_validator("message_id", "conversation_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
|
|||
@console_ns.doc("get_agent_logs")
|
||||
@console_ns.doc(description="Get agent execution logs for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
||||
)
|
||||
|
|
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
|
|||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
args = parser.parse_args()
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -21,22 +22,79 @@ from libs.helper import uuid_value
|
|||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AnnotationReplyPayload(BaseModel):
|
||||
score_threshold: float = Field(..., description="Score threshold for annotation matching")
|
||||
embedding_provider_name: str = Field(..., description="Embedding provider name")
|
||||
embedding_model_name: str = Field(..., description="Embedding model name")
|
||||
|
||||
|
||||
class AnnotationSettingUpdatePayload(BaseModel):
|
||||
score_threshold: float = Field(..., description="Score threshold")
|
||||
|
||||
|
||||
class AnnotationListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, description="Page size")
|
||||
keyword: str = Field(default="", description="Search keyword")
|
||||
|
||||
|
||||
class CreateAnnotationPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
question: str | None = Field(default=None, description="Question text")
|
||||
answer: str | None = Field(default=None, description="Answer text")
|
||||
content: str | None = Field(default=None, description="Content text")
|
||||
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class UpdateAnnotationPayload(BaseModel):
|
||||
question: str | None = None
|
||||
answer: str | None = None
|
||||
content: str | None = None
|
||||
annotation_reply: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AnnotationReplyStatusQuery(BaseModel):
|
||||
action: Literal["enable", "disable"]
|
||||
|
||||
|
||||
class AnnotationFilePayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
def reg(model: type[BaseModel]) -> None:
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AnnotationReplyPayload)
|
||||
reg(AnnotationSettingUpdatePayload)
|
||||
reg(AnnotationListQuery)
|
||||
reg(CreateAnnotationPayload)
|
||||
reg(UpdateAnnotationPayload)
|
||||
reg(AnnotationReplyStatusQuery)
|
||||
reg(AnnotationFilePayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@console_ns.doc("annotation_reply_action")
|
||||
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationReplyActionRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
|
||||
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
|
||||
@console_ns.response(200, "Action completed successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -46,15 +104,9 @@ class AnnotationReplyActionApi(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
app_id = str(app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||
elif action == "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
return result, 200
|
||||
|
|
@ -82,16 +134,7 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
@console_ns.doc("update_annotation_setting")
|
||||
@console_ns.doc(description="Update annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationSettingUpdateRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
|
||||
"embedding_model_name": fields.String(required=True, description="Embedding model"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Settings updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -102,10 +145,9 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
app_id = str(app_id)
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
|
||||
return result, 200
|
||||
|
||||
|
||||
|
|
@ -142,12 +184,7 @@ class AnnotationApi(Resource):
|
|||
@console_ns.doc("list_annotations")
|
||||
@console_ns.doc(description="Get annotations for an app with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
|
||||
@console_ns.response(200, "Annotations retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -155,9 +192,10 @@ class AnnotationApi(Resource):
|
|||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
keyword = args.keyword
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
|
|
@ -173,18 +211,7 @@ class AnnotationApi(Resource):
|
|||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
|
|
@ -195,16 +222,9 @@ class AnnotationApi(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=False, type=str, location="json")
|
||||
.add_argument("answer", required=False, type=str, location="json")
|
||||
.add_argument("content", required=False, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
data = args.model_dump(exclude_none=True)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
|
|
@ -256,13 +276,6 @@ class AnnotationExportApi(Resource):
|
|||
return response, 200
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@console_ns.doc("update_delete_annotation")
|
||||
|
|
@ -271,7 +284,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -281,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||
)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from fields.app_fields import (
|
|||
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import App, Workflow
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
|
|
@ -76,51 +75,30 @@ class AppListQuery(BaseModel):
|
|||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
|
|
@ -146,7 +124,14 @@ class AppApiStatusPayload(BaseModel):
|
|||
|
||||
class AppTracePayload(BaseModel):
|
||||
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||
tracing_provider: str = Field(..., description="Tracing provider")
|
||||
tracing_provider: str | None = Field(default=None, description="Tracing provider")
|
||||
|
||||
@field_validator("tracing_provider")
|
||||
@classmethod
|
||||
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
|
||||
if info.data.get("enabled") and not value:
|
||||
raise ValueError("tracing_provider is required when enabled is True")
|
||||
return value
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
|
|
@ -324,10 +309,13 @@ class AppListApi(Resource):
|
|||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
try:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
|
|||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||
)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppImportPayload(BaseModel):
|
||||
mode: str = Field(..., description="Import mode")
|
||||
yaml_content: str | None = None
|
||||
yaml_url: str | None = None
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -61,7 +68,7 @@ class AppImportApi(Resource):
|
|||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser.parse_args()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -70,15 +77,15 @@ class AppImportApi(Resource):
|
|||
account = current_user
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
app_id=args.get("app_id"),
|
||||
import_mode=args.mode,
|
||||
yaml_content=args.yaml_content,
|
||||
yaml_url=args.yaml_url,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
|
|
@ -32,6 +33,27 @@ from services.errors.audio import (
|
|||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TextToSpeechPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
text: str = Field(..., description="Text to convert")
|
||||
voice: str | None = Field(default=None, description="Voice name")
|
||||
streaming: bool | None = Field(default=None, description="Whether to stream audio")
|
||||
|
||||
|
||||
class TextToSpeechVoiceQuery(BaseModel):
|
||||
language: str = Field(..., description="Language code")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechVoiceQuery.__name__,
|
||||
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
|
|
@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
|
|||
@console_ns.doc("chat_message_text_to_speech")
|
||||
@console_ns.doc(description="Convert text to speech for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TextToSpeechRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"text": fields.String(required=True, description="Text to convert to speech"),
|
||||
"voice": fields.String(description="Voice to use for TTS"),
|
||||
"streaming": fields.Boolean(description="Whether to stream the audio"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
|
||||
@console_ns.response(200, "Text to speech conversion successful")
|
||||
@console_ns.response(400, "Bad request - Invalid parameters")
|
||||
@get_app_model
|
||||
|
|
@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
try:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
payload = TextToSpeechPayload.model_validate(console_ns.payload)
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
|
||||
app_model=app_model,
|
||||
text=payload.text,
|
||||
voice=payload.voice,
|
||||
message_id=payload.message_id,
|
||||
is_draft=True,
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
|
|
@ -159,9 +164,7 @@ class TextModesApi(Resource):
|
|||
@console_ns.doc("get_text_to_speech_voices")
|
||||
@console_ns.doc(description="Get available TTS voices for a specific language")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
||||
)
|
||||
|
|
@ -172,12 +175,11 @@ class TextModesApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
try:
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
language=args["language"],
|
||||
language=args.language,
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ class CompletionConversationQuery(BaseConversationQuery):
|
|||
|
||||
|
||||
class ChatConversationQuery(BaseConversationQuery):
|
||||
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort field and direction"
|
||||
)
|
||||
|
|
@ -509,14 +508,6 @@ class ChatConversationApi(Resource):
|
|||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
|
||||
if args.message_count_gte and args.message_count_gte >= 1:
|
||||
query = (
|
||||
query.options(joinedload(Conversation.messages)) # type: ignore
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(Message.id) >= args.message_count_gte)
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
|
||||
|
|
@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
|
|||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
|
||||
|
||||
class MCPServerUpdatePayload(BaseModel):
|
||||
id: str = Field(..., description="Server ID")
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
status: str | None = Field(default=None, description="Server status")
|
||||
|
||||
|
||||
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
class AppMCPServerController(Resource):
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
|
|
@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
|
|||
@console_ns.doc("create_app_mcp_server")
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerCreateRequest",
|
||||
{
|
||||
"description": fields.String(description="Server description"),
|
||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
|
|
@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
|
|||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
description = args.get("description")
|
||||
description = payload.description
|
||||
if not description:
|
||||
description = app_model.description or ""
|
||||
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=description,
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
parameters=json.dumps(payload.parameters, ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_tenant_id,
|
||||
|
|
@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
|
|||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerUpdateRequest",
|
||||
{
|
||||
"id": fields.String(required=True, description="Server ID"),
|
||||
"description": fields.String(description="Server description"),
|
||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
||||
"status": fields.String(description="Server status"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
|
|
@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
|
|||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("id", type=str, required=True, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
.add_argument("status", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
description = args.get("description")
|
||||
description = payload.description
|
||||
if description is None:
|
||||
pass
|
||||
elif not description:
|
||||
|
|
@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
|
|||
else:
|
||||
server.description = description
|
||||
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
|
||||
if payload.status:
|
||||
if payload.status not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = args["status"]
|
||||
server.status = payload.status
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel):
|
|||
class MessageFeedbackPayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
|
|
@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource):
|
|||
db.session.delete(feedback)
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
feedback.content = args.content
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
|
|
@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource):
|
|||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating_value,
|
||||
content=args.content,
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from libs.login import login_required
|
||||
from services.ops_service import OpsService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TraceProviderQuery(BaseModel):
|
||||
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||
|
||||
|
||||
class TraceConfigPayload(BaseModel):
|
||||
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TraceProviderQuery.__name__,
|
||||
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
||||
class TraceAppConfigApi(Resource):
|
||||
|
|
@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
|
|||
@console_ns.doc("get_trace_app_config")
|
||||
@console_ns.doc(description="Get tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||
)
|
||||
|
|
@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
if not trace_config:
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
|
|
@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
|
|||
@console_ns.doc("create_trace_app_config")
|
||||
@console_ns.doc(description="Create a new tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigCreateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||
@console_ns.response(
|
||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||
)
|
||||
|
|
@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
"""Create a new trace app configuration"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
|
|
@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
|
|||
@console_ns.doc("update_trace_app_config")
|
||||
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigUpdateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
|
|
@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
|
|||
@account_initialization_required
|
||||
def patch(self, app_id):
|
||||
"""Update an existing trace app configuration"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
|
|
@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
|
|||
@console_ns.doc("delete_trace_app_config")
|
||||
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||
@console_ns.response(204, "Tracing configuration deleted successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
|
|
@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
|
|||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
"""Delete an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
|
|
@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppSiteUpdatePayload(BaseModel):
|
||||
title: str | None = Field(default=None)
|
||||
icon_type: str | None = Field(default=None)
|
||||
icon: str | None = Field(default=None)
|
||||
icon_background: str | None = Field(default=None)
|
||||
description: str | None = Field(default=None)
|
||||
default_language: str | None = Field(default=None)
|
||||
chat_color_theme: str | None = Field(default=None)
|
||||
chat_color_theme_inverted: bool | None = Field(default=None)
|
||||
customize_domain: str | None = Field(default=None)
|
||||
copyright: str | None = Field(default=None)
|
||||
privacy_policy: str | None = Field(default=None)
|
||||
custom_disclaimer: str | None = Field(default=None)
|
||||
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
|
||||
prompt_public: bool | None = Field(default=None)
|
||||
show_workflow_steps: bool | None = Field(default=None)
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None)
|
||||
|
||||
@field_validator("default_language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AppSiteUpdatePayload.__name__,
|
||||
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("title", type=str, required=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, location="json")
|
||||
.add_argument("icon", type=str, required=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
.add_argument("copyright", type=str, required=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
.add_argument(
|
||||
"customize_token_strategy",
|
||||
type=str,
|
||||
choices=["must", "allow", "not_allow"],
|
||||
required=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
class AppSite(Resource):
|
||||
@console_ns.doc("update_app_site")
|
||||
@console_ns.doc(description="Update application site configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteRequest",
|
||||
{
|
||||
"title": fields.String(description="Site title"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"description": fields.String(description="Site description"),
|
||||
"default_language": fields.String(description="Default language"),
|
||||
"chat_color_theme": fields.String(description="Chat color theme"),
|
||||
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
|
||||
"customize_domain": fields.String(description="Custom domain"),
|
||||
"copyright": fields.String(description="Copyright text"),
|
||||
"privacy_policy": fields.String(description="Privacy policy"),
|
||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
||||
"customize_token_strategy": fields.String(
|
||||
enum=["must", "allow", "not_allow"], description="Token strategy"
|
||||
),
|
||||
"prompt_public": fields.Boolean(description="Make prompt public"),
|
||||
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "App not found")
|
||||
|
|
@ -89,7 +73,7 @@ class AppSite(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if not site:
|
||||
|
|
@ -113,7 +97,7 @@ class AppSite(Resource):
|
|||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
value = getattr(args, attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import NoReturn, ParamSpec, TypeVar
|
||||
from typing import Any, NoReturn, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
|||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowDraftVariableListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
|
||||
|
||||
|
||||
class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Variable name")
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
|
|
@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
|||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
value_type = workflow_draft_var.value_type
|
||||
return value_type.exposed_type().value
|
||||
|
|
@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.expect(_create_pagination_parser())
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
@console_ns.doc("get_workflow_variables")
|
||||
@console_ns.doc(description="Get draft workflow variables")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
|
|||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
|
|
@ -323,15 +328,7 @@ class VariableApi(Resource):
|
|||
|
||||
@console_ns.doc("update_variable")
|
||||
@console_ns.doc(description="Update a workflow variable")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateVariableRequest",
|
||||
{
|
||||
"name": fields.String(description="Variable name"),
|
||||
"value": fields.Raw(description="Variable value"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
|
|
@ -358,16 +355,10 @@ class VariableApi(Resource):
|
|||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
|
|
@ -375,8 +366,8 @@ class VariableApi(Resource):
|
|||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -1,28 +1,53 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
active_check_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||
)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ActivateCheckQuery(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
|
||||
|
||||
class ActivatePayload(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
name: str = Field(..., max_length=30)
|
||||
interface_language: str = Field(...)
|
||||
timezone: str = Field(...)
|
||||
|
||||
@field_validator("interface_language")
|
||||
@classmethod
|
||||
def validate_lang(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_tz(cls, value: str) -> str:
|
||||
return timezone(value)
|
||||
|
||||
|
||||
for model in (ActivateCheckQuery, ActivatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
class ActivateCheckApi(Resource):
|
||||
@console_ns.doc("check_activation_token")
|
||||
@console_ns.doc(description="Check if activation token is valid")
|
||||
@console_ns.expect(active_check_parser)
|
||||
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
|
|||
),
|
||||
)
|
||||
def get(self):
|
||||
args = active_check_parser.parse_args()
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workspaceId = args["workspace_id"]
|
||||
reg_email = args["email"]
|
||||
token = args["token"]
|
||||
workspaceId = args.workspace_id
|
||||
reg_email = args.email
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
if invitation:
|
||||
|
|
@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
|
|||
return {"is_valid": False}
|
||||
|
||||
|
||||
active_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate")
|
||||
class ActivateApi(Resource):
|
||||
@console_ns.doc("activate_account")
|
||||
@console_ns.doc(description="Activate account with invitation token")
|
||||
@console_ns.expect(active_parser)
|
||||
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
|
|
@ -85,19 +99,19 @@ class ActivateApi(Resource):
|
|||
)
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
args = active_parser.parse_args()
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
|
||||
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args["name"]
|
||||
account.name = args.name
|
||||
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_language = args.interface_language
|
||||
account.timezone = args.timezone
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,26 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
provider: str = Field(...)
|
||||
credentials: dict = Field(...)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ApiKeyAuthBindingPayload.__name__,
|
||||
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
|
|
@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -5,12 +5,11 @@ from flask import current_app, redirect, request
|
|||
from flask_restx import Resource, fields
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -14,16 +15,45 @@ from controllers.console.auth.error import (
|
|||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
|
||||
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EmailRegisterSendPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
language: str | None = Field(default=None, description="Language code")
|
||||
|
||||
|
||||
class EmailRegisterValidityPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class EmailRegisterResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
class EmailRegisterSendEmailApi(Resource):
|
||||
|
|
@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
|
|||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
language = "en-US"
|
||||
if args["language"] in languages:
|
||||
language = args["language"]
|
||||
if args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
|
|
@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
|
|||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args.email
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
token_data = AccountService.get_email_register_data(args["token"])
|
||||
token_data = AccountService.get_email_register_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args["email"])
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_email_register_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "register"}
|
||||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args["email"])
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
|
|
@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
|
|||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
if args.new_password != args.password_confirm:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get register data
|
||||
register_data = AccountService.get_email_register_data(args["token"])
|
||||
register_data = AccountService.get_email_register_data(args.token)
|
||||
if not register_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
|
|
@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
|
|||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
email = register_data.get("email", "")
|
||||
|
||||
|
|
@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
|
|||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args["password_confirm"])
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import base64
|
|||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
|||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@console_ns.doc("send_forgot_password_email")
|
||||
@console_ns.doc(description="Send password reset email")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
|
|
@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args["email"],
|
||||
email=args.email,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
|
|
@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
class ForgotPasswordCheckApi(Resource):
|
||||
@console_ns.doc("check_forgot_password_code")
|
||||
@console_ns.doc(description="Verify password reset code")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"code": fields.String(required=True, description="Verification code"),
|
||||
"token": fields.String(required=True, description="Reset token"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
|
|
@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
|
|||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args.email
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
token_data = AccountService.get_reset_password_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
AccountService.revoke_reset_password_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
|
|
@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
|
|||
class ForgotPasswordResetApi(Resource):
|
||||
@console_ns.doc("reset_password")
|
||||
@console_ns.doc(description="Reset password with verification token")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordResetRequest",
|
||||
{
|
||||
"token": fields.String(required=True, description="Verification token"),
|
||||
"new_password": fields.String(required=True, description="New password"),
|
||||
"password_confirm": fields.String(required=True, description="Password confirmation"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
|
|
@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
|
|||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
if args.new_password != args.password_confirm:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get reset data
|
||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
||||
reset_data = AccountService.get_reset_password_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
|
|
@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
|
|||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
AccountService.revoke_reset_password_token(args.token)
|
||||
|
||||
# Generate secure salt and hash password
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hashed = hash_password(args["new_password"], salt)
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
|
|
@ -23,7 +24,7 @@ from controllers.console.error import (
|
|||
)
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
|
|
@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
|
|||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
password: str = Field(..., description="Password")
|
||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||
|
||||
|
||||
class EmailPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class EmailCodeLoginPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(LoginPayload)
|
||||
reg(EmailPayload)
|
||||
reg(EmailCodeLoginPayload)
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
|
|
@ -47,41 +78,36 @@ class LoginApi(Resource):
|
|||
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=str, required=True, location="json")
|
||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invitation = args["invite_token"]
|
||||
# TODO: why invitation is re-assigned with different type?
|
||||
invitation = args.invite_token # type: ignore
|
||||
if invitation:
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
|
||||
|
||||
try:
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
data = invitation.get("data", {}) # type: ignore
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args["email"]:
|
||||
if invitee_email != args.email:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
|
||||
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||
else:
|
||||
account = AccountService.authenticate(args["email"], args["password"])
|
||||
account = AccountService.authenticate(args.email, args.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
AccountService.add_login_error_rate_limit(args.email)
|
||||
raise AuthenticationFailedError()
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
|
|
@ -97,7 +123,7 @@ class LoginApi(Resource):
|
|||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
|
@ -134,25 +160,21 @@ class LogoutApi(Resource):
|
|||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args["email"],
|
||||
email=args.email,
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
|
|
@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
|
|||
@console_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
|
||||
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
|
|
@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
@console_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
language = args["language"]
|
||||
user_email = args.email
|
||||
language = args.language
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
if token_data["email"] != args.email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
if token_data["code"] != args.code:
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
except AccountRegisterError:
|
||||
|
|
@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
|
|||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from functools import wraps
|
|||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
|
|
@ -20,15 +21,34 @@ R = TypeVar("R")
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OAuthClientPayload(BaseModel):
|
||||
client_id: str
|
||||
|
||||
|
||||
class OAuthProviderRequest(BaseModel):
|
||||
client_id: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
class OAuthTokenRequest(BaseModel):
|
||||
client_id: str
|
||||
grant_type: str
|
||||
code: str | None = None
|
||||
client_secret: str | None = None
|
||||
redirect_uri: str | None = None
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
json_data = request.get_json()
|
||||
if json_data is None:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
payload = OAuthClientPayload.model_validate(json_data)
|
||||
client_id = payload.client_id
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
|
@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
|
|||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
payload = OAuthProviderRequest.model_validate(request.get_json())
|
||||
redirect_uri = payload.redirect_uri
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
|
|
@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
|
|||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("grant_type", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=False, location="json")
|
||||
.add_argument("client_secret", type=str, required=False, location="json")
|
||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
)
|
||||
parsed_args = parser.parse_args()
|
||||
payload = OAuthTokenRequest.model_validate(request.get_json())
|
||||
|
||||
try:
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
grant_type = OAuthGrantType(payload.grant_type)
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
|
|||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not parsed_args["refresh_token"]:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import base64
|
||||
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SubscriptionQuery(BaseModel):
|
||||
plan: str = Field(..., description="Subscription plan")
|
||||
interval: str = Field(..., description="Billing interval")
|
||||
|
||||
@field_validator("plan")
|
||||
@classmethod
|
||||
def validate_plan(cls, value: str) -> str:
|
||||
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
|
||||
raise ValueError("Invalid plan")
|
||||
return value
|
||||
|
||||
@field_validator("interval")
|
||||
@classmethod
|
||||
def validate_interval(cls, value: str) -> str:
|
||||
if value not in {"month", "year"}:
|
||||
raise ValueError("Invalid interval")
|
||||
return value
|
||||
|
||||
|
||||
class PartnerTenantsPayload(BaseModel):
|
||||
click_id: str = Field(..., description="Click Id from partner referral link")
|
||||
|
||||
|
||||
for model in (SubscriptionQuery, PartnerTenantsPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/billing/subscription")
|
||||
class Subscription(Resource):
|
||||
|
|
@ -18,20 +49,9 @@ class Subscription(Resource):
|
|||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"plan",
|
||||
type=str,
|
||||
required=True,
|
||||
location="args",
|
||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||
)
|
||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/invoices")
|
||||
|
|
@ -65,11 +85,10 @@ class PartnerTenants(Resource):
|
|||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
click_id = args["click_id"]
|
||||
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||
click_id = args.click_id
|
||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||
except Exception:
|
||||
raise BadRequest("Invalid partner_key")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -9,16 +10,28 @@ from .. import console_ns
|
|||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
|
||||
|
||||
class ComplianceDownloadQuery(BaseModel):
|
||||
doc_name: str = Field(..., description="Compliance document name")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ComplianceDownloadQuery.__name__,
|
||||
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/compliance/download")
|
||||
class ComplianceApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
|
||||
@console_ns.doc("download_compliance_document")
|
||||
@console_ns.doc(description="Get compliance document download link")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.common.schema import register_schema_model
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
|
|
@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService
|
|||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
class NotionEstimatePayload(BaseModel):
|
||||
notion_info_list: list[dict[str, Any]]
|
||||
process_rule: dict[str, Any]
|
||||
doc_form: str = Field(default="text_model")
|
||||
doc_language: str = Field(default="English")
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/data-source/integrates",
|
||||
|
|
@ -243,20 +256,15 @@ class DataSourceNotionApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
notion_info_list = args["notion_info_list"]
|
||||
notion_info_list = payload.notion_info_list
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import (
|
||||
api_key_item_model,
|
||||
|
|
@ -48,7 +50,6 @@ from fields.dataset_fields import (
|
|||
)
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
|
@ -107,10 +108,75 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
|
|||
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
def _validate_indexing_technique(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
if value not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||
raise ValueError("Invalid indexing technique.")
|
||||
return value
|
||||
|
||||
|
||||
class DatasetCreatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field("", max_length=400)
|
||||
indexing_technique: str | None = None
|
||||
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
|
||||
provider: str = "vendor"
|
||||
external_knowledge_api_id: str | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
|
||||
@field_validator("indexing_technique")
|
||||
@classmethod
|
||||
def validate_indexing(cls, value: str | None) -> str | None:
|
||||
return _validate_indexing_technique(value)
|
||||
|
||||
@field_validator("provider")
|
||||
@classmethod
|
||||
def validate_provider(cls, value: str) -> str:
|
||||
if value not in Dataset.PROVIDER_LIST:
|
||||
raise ValueError("Invalid provider.")
|
||||
return value
|
||||
|
||||
|
||||
class DatasetUpdatePayload(BaseModel):
|
||||
name: str | None = Field(None, min_length=1, max_length=40)
|
||||
description: str | None = Field(None, max_length=400)
|
||||
permission: DatasetPermissionEnum | None = None
|
||||
indexing_technique: str | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
partial_member_list: list[str] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
external_knowledge_api_id: str | None = None
|
||||
icon_info: dict[str, Any] | None = None
|
||||
is_multimodal: bool | None = False
|
||||
|
||||
@field_validator("indexing_technique")
|
||||
@classmethod
|
||||
def validate_indexing(cls, value: str | None) -> str | None:
|
||||
return _validate_indexing_technique(value)
|
||||
|
||||
|
||||
class IndexingEstimatePayload(BaseModel):
|
||||
info_list: dict[str, Any]
|
||||
process_rule: dict[str, Any]
|
||||
indexing_technique: str
|
||||
doc_form: str = "text_model"
|
||||
dataset_id: str | None = None
|
||||
doc_language: str = "English"
|
||||
|
||||
@field_validator("indexing_technique")
|
||||
@classmethod
|
||||
def validate_indexing(cls, value: str) -> str:
|
||||
result = _validate_indexing_technique(value)
|
||||
if result is None:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
return result
|
||||
|
||||
|
||||
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
|
|
@ -255,20 +321,7 @@ class DatasetListApi(Resource):
|
|||
|
||||
@console_ns.doc("create_dataset")
|
||||
@console_ns.doc(description="Create a new dataset")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateDatasetRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
|
||||
"description": fields.String(description="Dataset description (max 400 characters)"),
|
||||
"indexing_technique": fields.String(description="Indexing technique"),
|
||||
"permission": fields.String(description="Dataset permission"),
|
||||
"provider": fields.String(description="Provider"),
|
||||
"external_knowledge_api_id": fields.String(description="External knowledge API ID"),
|
||||
"external_knowledge_id": fields.String(description="External knowledge ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
|
||||
@console_ns.response(201, "Dataset created successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
|
|
@ -276,52 +329,7 @@ class DatasetListApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
choices=Dataset.PROVIDER_LIST,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
|
|
@ -331,14 +339,14 @@ class DatasetListApi(Resource):
|
|||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=current_tenant_id,
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
indexing_technique=payload.indexing_technique,
|
||||
account=current_user,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
external_knowledge_id=args["external_knowledge_id"],
|
||||
permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
|
||||
provider=payload.provider,
|
||||
external_knowledge_api_id=payload.external_knowledge_api_id,
|
||||
external_knowledge_id=payload.external_knowledge_id,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
|
@ -399,18 +407,7 @@ class DatasetApi(Resource):
|
|||
|
||||
@console_ns.doc("update_dataset")
|
||||
@console_ns.doc(description="Update dataset details")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateDatasetRequest",
|
||||
{
|
||||
"name": fields.String(description="Dataset name"),
|
||||
"description": fields.String(description="Dataset description"),
|
||||
"permission": fields.String(description="Dataset permission"),
|
||||
"indexing_technique": fields.String(description="Indexing technique"),
|
||||
"external_retrieval_model": fields.Raw(description="External retrieval model settings"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
|
|
@ -424,93 +421,25 @@ class DatasetApi(Resource):
|
|||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(
|
||||
DatasetPermissionEnum.ONLY_ME,
|
||||
DatasetPermissionEnum.ALL_TEAM,
|
||||
DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
),
|
||||
help="Invalid permission.",
|
||||
)
|
||||
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||
.add_argument(
|
||||
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
||||
)
|
||||
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||
.add_argument(
|
||||
"external_retrieval_model",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external retrieval model.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge id.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid icon info.",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check embedding model setting
|
||||
if (
|
||||
data.get("indexing_technique") == "high_quality"
|
||||
and data.get("embedding_model_provider") is not None
|
||||
and data.get("embedding_model") is not None
|
||||
payload.indexing_technique == "high_quality"
|
||||
and payload.embedding_model_provider is not None
|
||||
and payload.embedding_model is not None
|
||||
):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
|
||||
)
|
||||
|
||||
payload.is_multimodal = is_multimodal
|
||||
payload_data = payload.model_dump(exclude_unset=True)
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
current_user, dataset, data.get("permission"), data.get("partial_member_list")
|
||||
current_user, dataset, payload.permission, payload.partial_member_list
|
||||
)
|
||||
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
|
||||
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
|
@ -518,15 +447,10 @@ class DatasetApi(Resource):
|
|||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant_id, dataset_id_str, data.get("partial_member_list")
|
||||
)
|
||||
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
|
||||
# clear partial member list when permission is only_me or all_team_members
|
||||
elif (
|
||||
data.get("permission") == DatasetPermissionEnum.ONLY_ME
|
||||
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
|
||||
):
|
||||
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
|
@ -615,24 +539,10 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
|
|
|
|||
|
|
@ -6,31 +6,14 @@ from typing import Literal, cast
|
|||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentAlreadyFinishedError,
|
||||
DocumentIndexingError,
|
||||
IndexingEstimateError,
|
||||
InvalidActionError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
|
|
@ -55,10 +38,30 @@ from fields.document_fields import (
|
|||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
|
||||
from ..app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from ..datasets.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentAlreadyFinishedError,
|
||||
DocumentIndexingError,
|
||||
IndexingEstimateError,
|
||||
InvalidActionError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
|
|||
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
||||
|
||||
|
||||
class DocumentRetryPayload(BaseModel):
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class DocumentRenamePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
KnowledgeConfig,
|
||||
ProcessRule,
|
||||
RetrievalModel,
|
||||
DocumentRetryPayload,
|
||||
DocumentRenamePayload,
|
||||
)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
|
@ -201,8 +222,9 @@ class DatasetDocumentListApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: str):
|
||||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
|
|
@ -310,6 +332,7 @@ class DatasetDocumentListApi(Resource):
|
|||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
|
|
@ -328,23 +351,7 @@ class DatasetDocumentListApi(Resource):
|
|||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=False, location="json")
|
||||
.add_argument("process_rule", type=dict, required=False, location="json")
|
||||
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
|
||||
|
||||
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
|
@ -390,17 +397,7 @@ class DatasetDocumentListApi(Resource):
|
|||
class DatasetInitApi(Resource):
|
||||
@console_ns.doc("init_dataset")
|
||||
@console_ns.doc(description="Initialize dataset with documents")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DatasetInitRequest",
|
||||
{
|
||||
"upload_file_id": fields.String(required=True, description="Upload file ID"),
|
||||
"indexing_technique": fields.String(description="Indexing technique"),
|
||||
"process_rule": fields.Raw(description="Processing rules"),
|
||||
"data_source": fields.Raw(description="Data source configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
|
|
@ -415,27 +412,7 @@ class DatasetInitApi(Resource):
|
|||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
required=True,
|
||||
nullable=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
|
|
@ -443,10 +420,14 @@ class DatasetInitApi(Resource):
|
|||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=args["embedding_model_provider"],
|
||||
provider=knowledge_config.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=args["embedding_model"],
|
||||
model=knowledge_config.embedding_model,
|
||||
)
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
|
||||
)
|
||||
knowledge_config.is_multimodal = is_multimodal
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
|
|
@ -1076,19 +1057,16 @@ class DocumentRetryApi(DocumentResource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"document_ids", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
retry_documents = []
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
for document_id in args["document_ids"]:
|
||||
for document_id in payload.document_ids:
|
||||
try:
|
||||
document_id = str(document_id)
|
||||
|
||||
|
|
@ -1121,6 +1099,7 @@ class DocumentRenameApi(DocumentResource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(document_fields)
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -1130,11 +1109,10 @@ class DocumentRenameApi(DocumentResource):
|
|||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
document = DocumentService.rename_document(dataset_id, document_id, args["name"])
|
||||
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import uuid
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
|
|
@ -36,6 +38,58 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
|||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
hit_count_gte: int | None = None
|
||||
enabled: str = Field(default="all")
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
content: str
|
||||
answer: str | None = None
|
||||
keywords: list[str] | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
||||
class SegmentUpdatePayload(BaseModel):
|
||||
content: str
|
||||
answer: str | None = None
|
||||
keywords: list[str] | None = None
|
||||
regenerate_child_chunks: bool = False
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
||||
class BatchImportPayload(BaseModel):
|
||||
upload_file_id: str
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SegmentListQuery,
|
||||
SegmentCreatePayload,
|
||||
SegmentUpdatePayload,
|
||||
BatchImportPayload,
|
||||
ChildChunkCreatePayload,
|
||||
ChildChunkUpdatePayload,
|
||||
ChildChunkBatchUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -60,23 +114,18 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
.add_argument("enabled", type=str, default="all", location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
**request.args.to_dict(),
|
||||
"status": request.args.getlist("status"),
|
||||
}
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
page = args["page"]
|
||||
limit = min(args["limit"], 100)
|
||||
status_list = args["status"]
|
||||
hit_count_gte = args["hit_count_gte"]
|
||||
keyword = args["keyword"]
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
status_list = args.status
|
||||
hit_count_gte = args.hit_count_gte
|
||||
keyword = args.keyword
|
||||
|
||||
query = (
|
||||
select(DocumentSegment)
|
||||
|
|
@ -96,10 +145,10 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
if keyword:
|
||||
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||
|
||||
if args["enabled"].lower() != "all":
|
||||
if args["enabled"].lower() == "true":
|
||||
if args.enabled.lower() != "all":
|
||||
if args.enabled.lower() == "true":
|
||||
query = query.where(DocumentSegment.enabled == True)
|
||||
elif args["enabled"].lower() == "false":
|
||||
elif args.enabled.lower() == "false":
|
||||
query = query.where(DocumentSegment.enabled == False)
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
|
@ -210,6 +259,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
|
@ -246,15 +296,10 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document, dataset)
|
||||
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
|
|
@ -265,6 +310,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
|
@ -313,18 +359,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
.add_argument(
|
||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||
)
|
||||
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
|
|
@ -377,6 +417,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
|
@ -391,11 +432,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"upload_file_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
upload_file_id = args["upload_file_id"]
|
||||
payload = BatchImportPayload.model_validate(console_ns.payload or {})
|
||||
upload_file_id = payload.upload_file_id
|
||||
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
if not upload_file:
|
||||
|
|
@ -446,6 +484,7 @@ class ChildChunkAddApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
|
@ -491,13 +530,9 @@ class ChildChunkAddApi(Resource):
|
|||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
|
||||
payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
|
||||
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
|
@ -529,18 +564,17 @@ class ChildChunkAddApi(Resource):
|
|||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"keyword": request.args.get("keyword"),
|
||||
"page": request.args.get("page", default=1, type=int),
|
||||
}
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
page = args["page"]
|
||||
limit = min(args["limit"], 100)
|
||||
keyword = args["keyword"]
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
keyword = args.keyword
|
||||
|
||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||
return {
|
||||
|
|
@ -588,14 +622,9 @@ class ChildChunkAddApi(Resource):
|
|||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"chunks", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
|
||||
try:
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||
|
|
@ -665,6 +694,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
|
@ -711,13 +741,9 @@ class ChildChunkUpdateApi(Resource):
|
|||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
|
||||
payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
|
||||
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, reqparse
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
|
@ -71,10 +73,38 @@ except KeyError:
|
|||
dataset_detail_model = _build_dataset_detail_model()
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
return name
|
||||
class ExternalKnowledgeApiPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
settings: dict[str, object]
|
||||
|
||||
|
||||
class ExternalDatasetCreatePayload(BaseModel):
|
||||
external_knowledge_api_id: str
|
||||
external_knowledge_id: str
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str | None = Field(None, max_length=400)
|
||||
external_retrieval_model: dict[str, object] | None = None
|
||||
|
||||
|
||||
class ExternalHitTestingPayload(BaseModel):
|
||||
query: str
|
||||
external_retrieval_model: dict[str, object] | None = None
|
||||
metadata_filtering_conditions: dict[str, object] | None = None
|
||||
|
||||
|
||||
class BedrockRetrievalPayload(BaseModel):
|
||||
retrieval_setting: dict[str, object]
|
||||
query: str
|
||||
knowledge_id: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ExternalKnowledgeApiPayload,
|
||||
ExternalDatasetCreatePayload,
|
||||
ExternalHitTestingPayload,
|
||||
BedrockRetrievalPayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api")
|
||||
|
|
@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
|
|
@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
|
|||
|
||||
try:
|
||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, args=args
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
|
@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def patch(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
|
||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
args=args,
|
||||
args=payload.model_dump(),
|
||||
)
|
||||
|
||||
return external_knowledge_api.to_dict(), 200
|
||||
|
|
@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource):
|
|||
class ExternalDatasetCreateApi(Resource):
|
||||
@console_ns.doc("create_external_dataset")
|
||||
@console_ns.doc(description="Create external knowledge dataset")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateExternalDatasetRequest",
|
||||
{
|
||||
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
|
||||
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
|
||||
"name": fields.String(required=True, description="Dataset name"),
|
||||
"description": fields.String(description="Dataset description"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
|
||||
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
|
|
@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
|
|||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
|
|
@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||
@console_ns.doc("test_external_knowledge_retrieval")
|
||||
@console_ns.doc(description="Test external knowledge retrieval for dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ExternalHitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
|
||||
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
|
||||
@console_ns.response(200, "External hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
|
|
@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
|
||||
HitTestingService.hit_testing_args_check(payload.model_dump())
|
||||
|
||||
try:
|
||||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
query=payload.query,
|
||||
account=current_user,
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||
external_retrieval_model=payload.external_retrieval_model,
|
||||
metadata_filtering_conditions=payload.metadata_filtering_conditions,
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource):
|
|||
# this api is only for internal testing
|
||||
@console_ns.doc("bedrock_retrieval_test")
|
||||
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"BedrockRetrievalTestRequest",
|
||||
{
|
||||
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
|
||||
"query": fields.String(required=True, description="Query text"),
|
||||
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
|
||||
@console_ns.response(200, "Bedrock retrieval test completed")
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||
.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Call the knowledge retrieval service
|
||||
result = ExternalDatasetTestService.knowledge_retrieval(
|
||||
args["retrieval_setting"], args["query"], args["knowledge_id"]
|
||||
payload.retrieval_setting, payload.query, payload.knowledge_id
|
||||
)
|
||||
return result, 200
|
||||
|
|
|
|||
|
|
@ -1,13 +1,17 @@
|
|||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.wraps import (
|
||||
from controllers.common.schema import register_schema_model
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from libs.login import login_required
|
||||
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
|
|
@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
|||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"HitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
|
||||
"top_k": fields.Integer(description="Number of top results to return"),
|
||||
"score_threshold": fields.Float(description="Score threshold for filtering results"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
|
|
@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
|||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
payload = HitTestingPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import marshal, reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -27,6 +29,13 @@ from services.hit_testing_service import HitTestingService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HitTestingPayload(BaseModel):
|
||||
query: str = Field(max_length=250)
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
|
|
@ -43,14 +52,15 @@ class DatasetsHitTestingBase:
|
|||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args):
|
||||
def hit_testing_args_check(args: dict[str, Any]):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("query", type=str, required=False, location="json")
|
||||
.add_argument("attachment_ids", type=list, required=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
|
|
@ -62,10 +72,11 @@ class DatasetsHitTestingBase:
|
|||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
query=args.get("query"),
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
retrieval_model=args.get("retrieval_model"),
|
||||
external_retrieval_model=args.get("external_retrieval_model"),
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
limit=10,
|
||||
)
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
|
|
@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
|||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
|
||||
register_schema_model(console_ns, MetadataUpdatePayload)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
class DatasetMetadataCreateApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
|
|||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
|
@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
|
|||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
name = args["name"]
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
|
|
@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
|
@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
|
|||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"operation_data", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,63 @@
|
|||
from typing import Any
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
class DatasourceCredentialPayload(BaseModel):
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class DatasourceCredentialDeletePayload(BaseModel):
|
||||
credential_id: str
|
||||
|
||||
|
||||
class DatasourceCredentialUpdatePayload(BaseModel):
|
||||
credential_id: str
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DatasourceCustomClientPayload(BaseModel):
|
||||
client_params: dict[str, Any] | None = None
|
||||
enable_oauth_custom_client: bool | None = None
|
||||
|
||||
|
||||
class DatasourceDefaultPayload(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class DatasourceUpdateNamePayload(BaseModel):
|
||||
credential_id: str
|
||||
name: str = Field(max_length=100)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DatasourceCredentialPayload,
|
||||
DatasourceCredentialDeletePayload,
|
||||
DatasourceCredentialUpdatePayload,
|
||||
DatasourceCustomClientPayload,
|
||||
DatasourceDefaultPayload,
|
||||
DatasourceUpdateNamePayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
|
|
@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_datasource = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@console_ns.expect(parser_datasource)
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
|
|||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_datasource.parse_args()
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
|
|
@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
|
|||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
credentials=payload.credentials,
|
||||
name=payload.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
|
@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
|
|||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@console_ns.expect(parser_datasource_delete)
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
|
|||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
|
||||
args = parser_datasource_delete.parse_args()
|
||||
payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
auth_id=payload.credential_id,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_datasource_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@console_ns.expect(parser_datasource_update)
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
|
|||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
args = parser_datasource_update.parse_args()
|
||||
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
auth_id=payload.credential_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
credentials=payload.credentials or {},
|
||||
name=payload.name,
|
||||
)
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
|
@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
|
|||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
parser_datasource_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@console_ns.expect(parser_datasource_custom)
|
||||
@console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_datasource_custom.parse_args()
|
||||
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
client_params=payload.client_params or {},
|
||||
enabled=payload.enable_oauth_custom_client or False,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@console_ns.expect(parser_default)
|
||||
@console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_default.parse_args()
|
||||
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
credential_id=payload.id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_update_name = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@console_ns.expect(parser_update_name)
|
||||
@console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_update_name.parse_args()
|
||||
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
name=payload.name,
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
|
||||
@console_ns.expect(console_ns.models[Parser.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description: str) -> str:
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates")
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
|
|||
return pipeline_template, 200
|
||||
|
||||
|
||||
class Payload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, Payload)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
|
|
@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Payload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from flask_restx import Resource, marshal, reqparse
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
|
|
@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
|
|||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineDatasetImportPayload(BaseModel):
|
||||
yaml_content: str
|
||||
|
||||
|
||||
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
|
|
@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
yaml_content=args["yaml_content"],
|
||||
yaml_content=payload.yaml_content,
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import logging
|
||||
from typing import NoReturn
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
|
|
@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
return parser
|
||||
class PaginationQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=100_000)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
register_schema_models(console_ns, PaginationQuery)
|
||||
|
||||
return PaginationQuery
|
||||
|
||||
|
||||
class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
name: str | None = None
|
||||
value: Any | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
|
|
@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
|
|||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
pagination = _create_pagination_parser()
|
||||
query = pagination.model_validate(request.args.to_dict())
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
|
@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
|
|||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=pipeline.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
|
@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
|
|||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
|
|
@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
|
|||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
|
|
@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
|
|||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineImportPayload(BaseModel):
|
||||
mode: str
|
||||
yaml_content: str | None = None
|
||||
yaml_url: str | None = None
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
pipeline_id: str | None = None
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
include_secret: str = Field(default="false")
|
||||
|
||||
|
||||
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports")
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
|
|||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("pipeline_id", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
|
|||
account = current_user
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
pipeline_id=args.get("pipeline_id"),
|
||||
dataset_name=args.get("name"),
|
||||
import_mode=payload.mode,
|
||||
yaml_content=payload.yaml_content,
|
||||
yaml_url=payload.yaml_url,
|
||||
pipeline_id=payload.pipeline_id,
|
||||
dataset_name=payload.name,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
|
||||
args = parser.parse_args()
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(
|
||||
pipeline=pipeline, include_secret=args["include_secret"] == "true"
|
||||
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
|
||||
from flask_restx.inputs import int_range # type: ignore
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
|
|
@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
|
|||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
|
|
@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DraftWorkflowSyncPayload(BaseModel):
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
environment_variables: list[dict[str, Any]] | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = None
|
||||
rag_pipeline_variables: list[dict[str, Any]] | None = None
|
||||
features: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class NodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class NodeRunRequiredPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
|
||||
|
||||
class DatasourceNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
datasource_type: str
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
class DraftWorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
datasource_type: str
|
||||
datasource_info_list: list[dict[str, Any]]
|
||||
start_node_id: str
|
||||
|
||||
|
||||
class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
|
||||
is_preview: bool = False
|
||||
response_mode: Literal["streaming", "blocking"] = "streaming"
|
||||
original_document_id: str | None = None
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class NodeIdQuery(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class DatasourceVariablesPayload(BaseModel):
|
||||
datasource_type: str
|
||||
datasource_info: dict[str, Any]
|
||||
start_node_id: str
|
||||
start_node_title: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DraftWorkflowSyncPayload,
|
||||
NodeRunPayload,
|
||||
NodeRunRequiredPayload,
|
||||
DatasourceNodeRunPayload,
|
||||
DraftWorkflowRunPayload,
|
||||
PublishedWorkflowRunPayload,
|
||||
DefaultBlockConfigQuery,
|
||||
WorkflowListQuery,
|
||||
WorkflowUpdatePayload,
|
||||
NodeIdQuery,
|
||||
WorkflowRunQuery,
|
||||
DatasourceVariablesPayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
||||
class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
|
|||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("hash", type=str, required=False, location="json")
|
||||
.add_argument("environment_variables", type=list, required=False, location="json")
|
||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload_dict = console_ns.payload or {}
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
|
|
@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
|
|||
if not isinstance(data.get("graph"), dict):
|
||||
raise ValueError("graph is not a dict")
|
||||
|
||||
args = {
|
||||
payload_dict = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
|
|
@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
|
|||
else:
|
||||
abort(415)
|
||||
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables_list = payload.environment_variables or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables_list = payload.conversation_variables or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=args["graph"],
|
||||
unique_hash=args.get("hash"),
|
||||
graph=payload.graph,
|
||||
unique_hash=payload.hash,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
|
||||
rag_pipeline_variables=payload.rag_pipeline_variables or [],
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
|
@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@console_ns.expect(parser_run)
|
||||
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_run.parse_args()
|
||||
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate_single_iteration(
|
||||
|
|
@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@console_ns.expect(parser_run)
|
||||
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_run.parse_args()
|
||||
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate_single_loop(
|
||||
|
|
@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
parser_draft_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@console_ns.expect(parser_draft_run)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_draft_run.parse_args()
|
||||
payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate(
|
||||
|
|
@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
|
|||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
parser_published_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||
class PublishedRagPipelineRunApi(Resource):
|
||||
@console_ns.expect(parser_published_run)
|
||||
@console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_published_run.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
|
||||
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
|
|
@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
#
|
||||
# return result
|
||||
#
|
||||
parser_rag_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@console_ns.expect(parser_rag_run)
|
||||
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_rag_run.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return helper.compact_generate_response(
|
||||
|
|
@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
rag_pipeline_service.run_datasource_workflow_node(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
user_inputs=payload.inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
datasource_type=payload.datasource_type,
|
||||
is_published=False,
|
||||
credential_id=args.get("credential_id"),
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@console_ns.expect(parser_rag_run)
|
||||
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_rag_run.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return helper.compact_generate_response(
|
||||
|
|
@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||
rag_pipeline_service.run_datasource_workflow_node(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
user_inputs=payload.inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
datasource_type=payload.datasource_type,
|
||||
is_published=False,
|
||||
credential_id=args.get("credential_id"),
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
parser_run_api = reqparse.RequestParser().add_argument(
|
||||
"inputs", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@console_ns.expect(parser_run_api)
|
||||
@console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_run_api.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
|
||||
inputs = payload.inputs
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
|
||||
|
|
@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
|||
return rag_pipeline_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@console_ns.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
|||
"""
|
||||
Get default block config
|
||||
"""
|
||||
args = parser_default.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
|
||||
|
||||
filters = None
|
||||
if q:
|
||||
if query.q:
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
filters = json.loads(query.q)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
|
|
@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
|||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_wf = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@console_ns.expect(parser_wf)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_wf.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
query = WorkflowListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
page = query.page
|
||||
limit = query.limit
|
||||
user_id = query.user_id
|
||||
named_only = query.named_only
|
||||
|
||||
if user_id:
|
||||
if user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
user_id = cast(str, user_id)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_wf_id = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@console_ns.expect(parser_wf_id)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
|
|||
# Check permission
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_wf_id.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if args.get("marked_name") is not None:
|
||||
update_data["marked_name"] = args["marked_name"]
|
||||
if args.get("marked_comment") is not None:
|
||||
update_data["marked_comment"] = args["marked_comment"]
|
||||
payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
|
@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
|
|||
return workflow
|
||||
|
||||
|
||||
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||
node_id = query.node_id
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
|
||||
return {
|
||||
|
|
@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||
class PublishedRagPipelineFirstStepApi(Resource):
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||
"""
|
||||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||
node_id = query.node_id
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
|
||||
return {
|
||||
|
|
@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||
class DraftRagPipelineFirstStepApi(Resource):
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||
"""
|
||||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||
node_id = query.node_id
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
|
||||
return {
|
||||
|
|
@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||
class DraftRagPipelineSecondStepApi(Resource):
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||
node_id = query.node_id
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
|
||||
|
|
@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_wf_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@console_ns.expect(parser_wf_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
|
|||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args = parser_wf_run.parse_args()
|
||||
query = WorkflowRunQuery.model_validate(
|
||||
{
|
||||
"last_id": request.args.get("last_id"),
|
||||
"limit": request.args.get("limit", type=int, default=20),
|
||||
}
|
||||
)
|
||||
args = {
|
||||
"last_id": str(query.last_id) if query.last_id else None,
|
||||
"limit": query.limit,
|
||||
}
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
|
||||
|
|
@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
|
|||
return result
|
||||
|
||||
|
||||
parser_var = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info", type=dict, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("start_node_title", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||
class RagPipelineDatasourceVariableApi(Resource):
|
||||
@console_ns.expect(parser_var)
|
||||
@console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
|||
Set datasource variables
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_var.parse_args()
|
||||
args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_node_execution = rag_pipeline_service.set_datasource_variables(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
|
|
@ -7,48 +12,35 @@ from libs.login import login_required
|
|||
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
|
||||
|
||||
|
||||
class WebsiteCrawlPayload(BaseModel):
|
||||
provider: Literal["firecrawl", "watercrawl", "jinareader"]
|
||||
url: str
|
||||
options: dict[str, object]
|
||||
|
||||
|
||||
class WebsiteCrawlStatusQuery(BaseModel):
|
||||
provider: Literal["firecrawl", "watercrawl", "jinareader"]
|
||||
|
||||
|
||||
register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery)
|
||||
|
||||
|
||||
@console_ns.route("/website/crawl")
|
||||
class WebsiteCrawlApi(Resource):
|
||||
@console_ns.doc("crawl_website")
|
||||
@console_ns.doc(description="Crawl website content")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WebsiteCrawlRequest",
|
||||
{
|
||||
"provider": fields.String(
|
||||
required=True,
|
||||
description="Crawl provider (firecrawl/watercrawl/jinareader)",
|
||||
enum=["firecrawl", "watercrawl", "jinareader"],
|
||||
),
|
||||
"url": fields.String(required=True, description="URL to crawl"),
|
||||
"options": fields.Raw(required=True, description="Crawl options"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__])
|
||||
@console_ns.response(200, "Website crawl initiated successfully")
|
||||
@console_ns.response(400, "Invalid crawl parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
choices=["firecrawl", "watercrawl", "jinareader"],
|
||||
required=True,
|
||||
nullable=True,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("url", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Create typed request and validate
|
||||
try:
|
||||
api_request = WebsiteCrawlApiRequest.from_args(args)
|
||||
api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump())
|
||||
except ValueError as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
|
||||
|
|
@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource):
|
|||
@console_ns.doc("get_crawl_status")
|
||||
@console_ns.doc(description="Get website crawl status")
|
||||
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
|
||||
@console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__])
|
||||
@console_ns.response(200, "Crawl status retrieved successfully")
|
||||
@console_ns.response(404, "Crawl job not found")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
|
|
@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id: str):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Create typed request and validate
|
||||
try:
|
||||
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
|
||||
api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id)
|
||||
except ValueError as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
|
|
@ -31,6 +33,16 @@ from .. import console_ns
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(console_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
|
||||
endpoint="installed_app_audio",
|
||||
|
|
@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource):
|
|||
endpoint="installed_app_text",
|
||||
)
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
|
||||
def post(self, installed_app):
|
||||
from flask_restx import reqparse
|
||||
|
||||
app_model = installed_app.app
|
||||
try:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, required=False, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = TextToAudioPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = payload.message_id
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import logging
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import reqparse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
|
|
@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError
|
|||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
|
|
@ -38,28 +40,56 @@ from .. import console_ns
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = Field(default="explore_app")
|
||||
|
||||
|
||||
class ChatMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
retriever_from: str = Field(default="explore_app")
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id", mode="before")
|
||||
@classmethod
|
||||
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
|
||||
"""
|
||||
Accept blank IDs and validate UUID format when provided.
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
try:
|
||||
return helper.uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("must be a valid UUID") from exc
|
||||
|
||||
|
||||
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages",
|
||||
endpoint="installed_app_completion",
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, location="json", default="")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
installed_app.last_used_at = naive_utc_now()
|
||||
|
|
@ -123,22 +153,15 @@ class CompletionStopApi(InstalledAppResource):
|
|||
endpoint="installed_app_chat_completion",
|
||||
)
|
||||
class ChatApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ChatMessagePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,18 @@
|
|||
from flask_restx import marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
|
|
@ -19,29 +23,51 @@ from services.web_conversation_service import WebConversationService
|
|||
from .. import console_ns
|
||||
|
||||
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
pinned: bool | None = None
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations",
|
||||
endpoint="installed_app_conversations",
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
pinned = None
|
||||
if "pinned" in args and args["pinned"] is not None:
|
||||
pinned = args["pinned"] == "true"
|
||||
raw_args: dict[str, Any] = {
|
||||
"last_id": request.args.get("last_id"),
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"pinned": request.args.get("pinned"),
|
||||
}
|
||||
if raw_args["last_id"] is None:
|
||||
raw_args["last_id"] = None
|
||||
pinned_value = raw_args["pinned"]
|
||||
if isinstance(pinned_value, str):
|
||||
raw_args["pinned"] = pinned_value == "true"
|
||||
args = ConversationListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
|
|
@ -51,10 +77,10 @@ class ConversationListApi(InstalledAppResource):
|
|||
session=session,
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
last_id=args["last_id"],
|
||||
limit=args["limit"],
|
||||
last_id=str(args.last_id) if args.last_id else None,
|
||||
limit=args.limit,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
pinned=pinned,
|
||||
pinned=args.pinned,
|
||||
)
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
|
@ -88,6 +114,7 @@ class ConversationApi(InstalledAppResource):
|
|||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -96,18 +123,13 @@ class ConversationRenameApi(InstalledAppResource):
|
|||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, location="json")
|
||||
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
return ConversationService.rename(
|
||||
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
||||
app_model, conversation_id, current_user, payload.name, payload.auto_generate
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
|
|
@ -22,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
|||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
|
@ -40,12 +43,31 @@ from .. import console_ns
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class MoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"]
|
||||
|
||||
|
||||
register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages",
|
||||
endpoint="installed_app_messages",
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
|
|
@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource):
|
|||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
.add_argument("first_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.conversation_id),
|
||||
str(args.first_id) if args.first_id else None,
|
||||
args.limit,
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource):
|
|||
endpoint="installed_app_message_feedback",
|
||||
)
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
def post(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
.add_argument("content", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
rating=args.get("rating"),
|
||||
content=args.get("content"),
|
||||
rating=payload.rating,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
|
@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||
endpoint="installed_app_more_like_this",
|
||||
)
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
|
||||
def get(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
|
|
@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = MoreLikeThisQuery.model_validate(request.args.to_dict())
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = args.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_more_like_this(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -35,20 +37,26 @@ recommended_app_list_fields = {
|
|||
}
|
||||
|
||||
|
||||
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
||||
class RecommendedAppsQuery(BaseModel):
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
RecommendedAppsQuery.__name__,
|
||||
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.expect(parser_apps)
|
||||
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(recommended_app_list_fields)
|
||||
def get(self):
|
||||
# language args
|
||||
args = parser_apps.parse_args()
|
||||
|
||||
language = args.get("language")
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
language = args.language
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
elif current_user and current_user.interface_language:
|
||||
|
|
|
|||
|
|
@ -1,16 +1,33 @@
|
|||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUID
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
message_fields = {
|
||||
|
|
@ -33,32 +50,33 @@ class SavedMessageListApi(InstalledAppResource):
|
|||
}
|
||||
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = SavedMessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.last_id) if args.last_id else None,
|
||||
args.limit,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||
SavedMessageService.save(app_model, current_user, str(payload.message_id))
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import reqparse
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
|
|
@ -32,8 +34,17 @@ from .. import console_ns
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
register_schema_model(console_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
"""
|
||||
Run workflow
|
||||
|
|
@ -46,12 +57,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
|
|
|
|||
|
|
@ -45,6 +45,9 @@ class FileApi(Resource):
|
|||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
|
||||
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
|
||||
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
import os
|
||||
|
||||
from flask import session
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
|
@ -15,6 +15,18 @@ from . import console_ns
|
|||
from .error import AlreadySetupError, InitValidateFailedError
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InitValidatePayload(BaseModel):
|
||||
password: str = Field(..., max_length=30)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InitValidatePayload.__name__,
|
||||
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/init")
|
||||
class InitValidateAPI(Resource):
|
||||
|
|
@ -37,12 +49,7 @@ class InitValidateAPI(Resource):
|
|||
|
||||
@console_ns.doc("validate_init_password")
|
||||
@console_ns.doc(description="Validate initialization password for self-hosted edition")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InitValidateRequest",
|
||||
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
|
||||
@console_ns.response(
|
||||
201,
|
||||
"Success",
|
||||
|
|
@ -57,8 +64,8 @@ class InitValidateAPI(Resource):
|
|||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
|
||||
input_password = parser.parse_args()["password"]
|
||||
payload = InitValidatePayload.model_validate(console_ns.payload)
|
||||
input_password = payload.password
|
||||
|
||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||
session["is_init_validated"] = False
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import services
|
||||
from controllers.common import helpers
|
||||
|
|
@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||
class RemoteFileUploadPayload(BaseModel):
|
||||
url: str = Field(..., description="URL to fetch")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
RemoteFileUploadPayload.__name__,
|
||||
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@console_ns.expect(parser_upload)
|
||||
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
args = parser_upload.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||
url = args.url
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import StrLen, email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup, db
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
|
@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError
|
|||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SetupRequestPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Admin email address")
|
||||
name: str = Field(..., max_length=30, description="Admin name (max 30 characters)")
|
||||
password: str = Field(..., description="Admin password")
|
||||
language: str | None = Field(default=None, description="Admin language")
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
SetupRequestPayload.__name__,
|
||||
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/setup")
|
||||
class SetupApi(Resource):
|
||||
|
|
@ -42,17 +63,7 @@ class SetupApi(Resource):
|
|||
|
||||
@console_ns.doc("setup_system")
|
||||
@console_ns.doc(description="Initialize system setup with admin account")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SetupRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Admin email address"),
|
||||
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
|
||||
"password": fields.String(required=True, description="Admin password"),
|
||||
"language": fields.String(required=False, description="Admin language"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
|
||||
@console_ns.response(
|
||||
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
|
||||
)
|
||||
|
|
@ -72,22 +83,15 @@ class SetupApi(Resource):
|
|||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||
.add_argument("password", type=valid_password, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args["email"],
|
||||
name=args["name"],
|
||||
password=args["password"],
|
||||
email=args.email,
|
||||
name=args.name,
|
||||
password=args.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
language=args["language"],
|
||||
language=args.language,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ import json
|
|||
import logging
|
||||
|
||||
import httpx
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
|
@ -11,8 +13,14 @@ from . import console_ns
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
||||
|
||||
class VersionQuery(BaseModel):
|
||||
current_version: str = Field(..., description="Current application version")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
VersionQuery.__name__,
|
||||
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument(
|
|||
class VersionApi(Resource):
|
||||
@console_ns.doc("check_version_update")
|
||||
@console_ns.doc(description="Check for application version updates")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.expect(console_ns.models[VersionQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -37,7 +45,7 @@ class VersionApi(Resource):
|
|||
)
|
||||
def get(self):
|
||||
"""Check for application version updates"""
|
||||
args = parser.parse_args()
|
||||
args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
result = {
|
||||
|
|
@ -57,16 +65,16 @@ class VersionApi(Resource):
|
|||
try:
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args["current_version"]},
|
||||
params={"current_version": args.current_version},
|
||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args["current_version"]
|
||||
result["version"] = args.current_version
|
||||
return result
|
||||
|
||||
content = json.loads(response.content)
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
|
||||
if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
|
||||
result["version"] = content["version"]
|
||||
result["release_date"] = content["releaseDate"]
|
||||
result["release_notes"] = content["releaseNotes"]
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, email, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
|
|
@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel):
|
|||
|
||||
|
||||
class AccountDeletionFeedbackPayload(BaseModel):
|
||||
email: str
|
||||
email: EmailStr
|
||||
feedback: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class EducationActivatePayload(BaseModel):
|
||||
token: str
|
||||
|
|
@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel):
|
|||
|
||||
|
||||
class ChangeEmailSendPayload(BaseModel):
|
||||
email: str
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
phase: str | None = None
|
||||
token: str | None = None
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class ChangeEmailValidityPayload(BaseModel):
|
||||
email: str
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class ChangeEmailResetPayload(BaseModel):
|
||||
new_email: str
|
||||
new_email: EmailStr
|
||||
token: str
|
||||
|
||||
@field_validator("new_email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
|
||||
|
||||
class CheckEmailUniquePayload(BaseModel):
|
||||
email: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, value: str) -> str:
|
||||
return email(value)
|
||||
email: EmailStr
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):
|
|||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
tenant_id=tenant_id, provider_name=provider
|
||||
)
|
||||
else:
|
||||
model_type = args.model_type
|
||||
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
|
||||
normalized_model_type = args.model_type.to_origin_model_type()
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
|
||||
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
|
||||
)
|
||||
|
||||
return jsonable_encoder(
|
||||
|
|
|
|||
|
|
@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri
|
|||
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource):
|
|||
class TriggerSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
|
|
@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
|||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
|
|
@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
|||
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
|
|
@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
|||
@console_ns.expect(parser_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Verify a subscription instance for a trigger provider"""
|
||||
|
|
@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
@console_ns.expect(parser_update_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
|
|
@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
|
|
@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
@console_ns.expect(parser_update_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -11,6 +12,26 @@ from extensions.ext_database import db
|
|||
from services.account_service import TenantService
|
||||
from services.file_service import FileService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class FileSignatureQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp used in the signature")
|
||||
nonce: str = Field(..., description="Random string for signature")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
|
||||
|
||||
class FilePreviewQuery(FileSignatureQuery):
|
||||
as_attachment: bool = Field(default=False, description="Whether to download as attachment")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
files_ns.schema_model(
|
||||
FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/<uuid:file_id>/image-preview")
|
||||
class ImagePreviewApi(Resource):
|
||||
|
|
@ -36,12 +57,10 @@ class ImagePreviewApi(Resource):
|
|||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
sign = request.args.get("sign")
|
||||
|
||||
if not timestamp or not nonce or not sign:
|
||||
return {"content": "Invalid request."}, 400
|
||||
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
timestamp = args.timestamp
|
||||
nonce = args.nonce
|
||||
sign = args.sign
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService(db.engine).get_image_preview(
|
||||
|
|
@ -80,25 +99,14 @@ class FilePreviewApi(Resource):
|
|||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("timestamp", type=str, required=True, location="args")
|
||||
.add_argument("nonce", type=str, required=True, location="args")
|
||||
.add_argument("sign", type=str, required=True, location="args")
|
||||
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
|
||||
return {"content": "Invalid request."}, 400
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
||||
file_id=file_id,
|
||||
timestamp=args["timestamp"],
|
||||
nonce=args["nonce"],
|
||||
sign=args["sign"],
|
||||
timestamp=args.timestamp,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
|
@ -125,7 +133,7 @@ class FilePreviewApi(Resource):
|
|||
response.headers["Accept-Ranges"] = "bytes"
|
||||
if upload_file.size > 0:
|
||||
response.headers["Content-Length"] = str(upload_file.size)
|
||||
if args["as_attachment"]:
|
||||
if args.as_attachment:
|
||||
encoded_filename = quote(upload_file.name)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Type"] = "application/octet-stream"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.common.errors import UnsupportedFileTypeError
|
||||
|
|
@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature
|
|||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ToolFileQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp")
|
||||
nonce: str = Field(..., description="Random nonce")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
as_attachment: bool = Field(default=False, description="Download as attachment")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||
class ToolFileApi(Resource):
|
||||
|
|
@ -36,18 +51,8 @@ class ToolFileApi(Resource):
|
|||
def get(self, file_id, extension):
|
||||
file_id = str(file_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("timestamp", type=str, required=True, location="args")
|
||||
.add_argument("nonce", type=str, required=True, location="args")
|
||||
.add_argument("sign", type=str, required=True, location="args")
|
||||
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not verify_tool_file_signature(
|
||||
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
|
||||
):
|
||||
args = ToolFileQuery.model_validate(request.args.to_dict())
|
||||
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
|
|
@ -69,7 +74,7 @@ class ToolFileApi(Resource):
|
|||
)
|
||||
if tool_file.size > 0:
|
||||
response.headers["Content-Length"] = str(tool_file.size)
|
||||
if args["as_attachment"]:
|
||||
if args.as_attachment:
|
||||
encoded_filename = quote(tool_file.name)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,40 +1,45 @@
|
|||
from mimetypes import guess_extension
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
FileTooLargeError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.files import files_ns
|
||||
from controllers.inner_api.plugin.wraps import get_user
|
||||
from core.file.helpers import verify_plugin_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from fields.file_fields import build_file_model
|
||||
|
||||
# Define parser for both documentation and validation
|
||||
upload_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
|
||||
.add_argument(
|
||||
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
|
||||
)
|
||||
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
|
||||
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
|
||||
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
|
||||
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
|
||||
from ..common.errors import (
|
||||
FileTooLargeError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from ..console.wraps import setup_required
|
||||
from ..files import files_ns
|
||||
from ..inner_api.plugin.wraps import get_user
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class PluginUploadQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp for signature verification")
|
||||
nonce: str = Field(..., description="Random nonce for signature verification")
|
||||
sign: str = Field(..., description="HMAC signature")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
user_id: str | None = Field(default=None, description="User identifier")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@files_ns.route("/upload/for-plugin")
|
||||
class PluginUploadFileApi(Resource):
|
||||
@setup_required
|
||||
@files_ns.expect(upload_parser)
|
||||
@files_ns.expect(files_ns.models[PluginUploadQuery.__name__])
|
||||
@files_ns.doc("upload_plugin_file")
|
||||
@files_ns.doc(description="Upload a file for plugin usage with signature verification")
|
||||
@files_ns.doc(
|
||||
|
|
@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource):
|
|||
FileTooLargeError: File exceeds size limit
|
||||
UnsupportedFileTypeError: File type not supported
|
||||
"""
|
||||
# Parse and validate all arguments
|
||||
args = upload_parser.parse_args()
|
||||
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
file: FileStorage = args["file"]
|
||||
timestamp: str = args["timestamp"]
|
||||
nonce: str = args["nonce"]
|
||||
sign: str = args["sign"]
|
||||
tenant_id: str = args["tenant_id"]
|
||||
user_id: str | None = args.get("user_id")
|
||||
file: FileStorage | None = request.files.get("file")
|
||||
if file is None:
|
||||
raise Forbidden("File is required.")
|
||||
|
||||
timestamp = args.timestamp
|
||||
nonce = args.nonce
|
||||
sign = args.sign
|
||||
tenant_id = args.tenant_id
|
||||
user_id = args.user_id
|
||||
user = get_user(tenant_id, user_id)
|
||||
|
||||
filename: str | None = file.filename
|
||||
|
|
|
|||
|
|
@ -1,29 +1,38 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
|
||||
from tasks.mail_inner_task import send_inner_email_task
|
||||
|
||||
_mail_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("to", type=str, action="append", required=True)
|
||||
.add_argument("subject", type=str, required=True)
|
||||
.add_argument("body", type=str, required=True)
|
||||
.add_argument("substitutions", type=dict, required=False)
|
||||
)
|
||||
|
||||
class InnerMailPayload(BaseModel):
|
||||
to: list[str] = Field(description="Recipient email addresses", min_length=1)
|
||||
subject: str
|
||||
body: str
|
||||
substitutions: dict[str, Any] | None = None
|
||||
|
||||
|
||||
register_schema_model(inner_api_ns, InnerMailPayload)
|
||||
|
||||
|
||||
class BaseMail(Resource):
|
||||
"""Shared logic for sending an inner email."""
|
||||
|
||||
@inner_api_ns.doc("send_inner_mail")
|
||||
@inner_api_ns.doc(description="Send internal email")
|
||||
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
|
||||
def post(self):
|
||||
args = _mail_parser.parse_args()
|
||||
send_inner_email_task.delay( # type: ignore
|
||||
to=args["to"],
|
||||
subject=args["subject"],
|
||||
body=args["body"],
|
||||
substitutions=args["substitutions"],
|
||||
args = InnerMailPayload.model_validate(inner_api_ns.payload or {})
|
||||
send_inner_email_task.delay(
|
||||
to=args.to,
|
||||
subject=args.subject,
|
||||
body=args.body,
|
||||
substitutions=args.substitutions, # type: ignore
|
||||
)
|
||||
return {"message": "success"}, 200
|
||||
|
||||
|
|
@ -34,7 +43,7 @@ class EnterpriseMail(BaseMail):
|
|||
|
||||
@inner_api_ns.doc("send_enterprise_mail")
|
||||
@inner_api_ns.doc(description="Send internal email for enterprise features")
|
||||
@inner_api_ns.expect(_mail_parser)
|
||||
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
|
|
@ -56,7 +65,7 @@ class BillingMail(BaseMail):
|
|||
|
||||
@inner_api_ns.doc("send_billing_mail")
|
||||
@inner_api_ns.doc(description="Send internal email for billing notifications")
|
||||
@inner_api_ns.expect(_mail_parser)
|
||||
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
|
||||
@inner_api_ns.doc(
|
||||
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, cast
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import reqparse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -17,6 +16,11 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TenantUserPayload(BaseModel):
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
|
||||
|
||||
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
"""
|
||||
Get current user
|
||||
|
|
@ -67,58 +71,45 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||
return user_model
|
||||
|
||||
|
||||
def get_user_tenant(view: Callable[P, R] | None = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
# fetch json body
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tenant_id", type=str, required=True, location="json")
|
||||
.add_argument("user_id", type=str, required=True, location="json")
|
||||
)
|
||||
def get_user_tenant(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
p = parser.parse_args()
|
||||
user_id = payload.user_id
|
||||
tenant_id = payload.tenant_id
|
||||
|
||||
user_id = cast(str, p.get("user_id"))
|
||||
tenant_id = cast(str, p.get("tenant_id"))
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required")
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required")
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.where(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.where(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
|
||||
user = get_user(tenant_id, user_id)
|
||||
kwargs["user_model"] = user
|
||||
user = get_user(tenant_id, user_id)
|
||||
kwargs["user_model"] = user
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
return decorated_view
|
||||
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import json
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
|
|
@ -11,12 +13,25 @@ from models import Account
|
|||
from services.account_service import TenantService
|
||||
|
||||
|
||||
class WorkspaceCreatePayload(BaseModel):
|
||||
name: str
|
||||
owner_email: str
|
||||
|
||||
|
||||
class WorkspaceOwnerlessPayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_models(inner_api_ns, WorkspaceCreatePayload, WorkspaceOwnerlessPayload)
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/workspace")
|
||||
class EnterpriseWorkspace(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc("create_enterprise_workspace")
|
||||
@inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment")
|
||||
@inner_api_ns.expect(inner_api_ns.models[WorkspaceCreatePayload.__name__])
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Workspace created successfully",
|
||||
|
|
@ -25,18 +40,13 @@ class EnterpriseWorkspace(Resource):
|
|||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, location="json")
|
||||
.add_argument("owner_email", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
|
||||
|
||||
account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
|
||||
account = db.session.query(Account).filter_by(email=args.owner_email).first()
|
||||
if account is None:
|
||||
return {"message": "owner account not found."}, 404
|
||||
|
||||
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
|
|
@ -62,6 +72,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
|||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc("create_enterprise_workspace_ownerless")
|
||||
@inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment")
|
||||
@inner_api_ns.expect(inner_api_ns.models[WorkspaceOwnerlessPayload.__name__])
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Workspace created successfully",
|
||||
|
|
@ -70,10 +81,9 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
|||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})
|
||||
|
||||
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
|
|
@ -24,27 +25,19 @@ class MCPRequestError(Exception):
|
|||
super().__init__(message)
|
||||
|
||||
|
||||
def int_or_str(value):
|
||||
"""Validate that a value is either an integer or string."""
|
||||
if isinstance(value, (int, str)):
|
||||
return value
|
||||
else:
|
||||
return None
|
||||
class MCPRequestPayload(BaseModel):
|
||||
jsonrpc: str = Field(description="JSON-RPC version (should be '2.0')")
|
||||
method: str = Field(description="The method to invoke")
|
||||
params: dict[str, Any] | None = Field(default=None, description="Parameters for the method")
|
||||
id: int | str | None = Field(default=None, description="Request ID for tracking responses")
|
||||
|
||||
|
||||
# Define parser for both documentation and validation
|
||||
mcp_request_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
|
||||
.add_argument("method", type=str, required=True, location="json", help="The method to invoke")
|
||||
.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
|
||||
.add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
|
||||
)
|
||||
register_schema_model(mcp_ns, MCPRequestPayload)
|
||||
|
||||
|
||||
@mcp_ns.route("/server/<string:server_code>/mcp")
|
||||
class MCPAppApi(Resource):
|
||||
@mcp_ns.expect(mcp_request_parser)
|
||||
@mcp_ns.expect(mcp_ns.models[MCPRequestPayload.__name__])
|
||||
@mcp_ns.doc("handle_mcp_request")
|
||||
@mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server")
|
||||
@mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"})
|
||||
|
|
@ -70,9 +63,9 @@ class MCPAppApi(Resource):
|
|||
Raises:
|
||||
ValidationError: Invalid request format or parameters
|
||||
"""
|
||||
args = mcp_request_parser.parse_args()
|
||||
request_id: Union[int, str] | None = args.get("id")
|
||||
mcp_request = self._parse_mcp_request(args)
|
||||
args = MCPRequestPayload.model_validate(mcp_ns.payload or {})
|
||||
request_id: Union[int, str] | None = args.id
|
||||
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get MCP server and app
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
|
|
@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model
|
|||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Define parsers for annotation API
|
||||
annotation_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json", help="Annotation question")
|
||||
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
|
||||
)
|
||||
|
||||
annotation_reply_action_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
|
||||
)
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
|
||||
)
|
||||
class AnnotationCreatePayload(BaseModel):
|
||||
question: str = Field(description="Annotation question")
|
||||
answer: str = Field(description="Annotation answer")
|
||||
|
||||
|
||||
class AnnotationReplyActionPayload(BaseModel):
|
||||
score_threshold: float = Field(description="Score threshold for annotation matching")
|
||||
embedding_provider_name: str = Field(description="Embedding provider name")
|
||||
embedding_model_name: str = Field(description="Embedding model name")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@service_api_ns.expect(annotation_reply_action_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
|
||||
@service_api_ns.doc("annotation_reply_action")
|
||||
@service_api_ns.doc(description="Enable or disable annotation reply feature")
|
||||
@service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
|
||||
|
|
@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource):
|
|||
@validate_app_token
|
||||
def post(self, app_model: App, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
args = annotation_reply_action_parser.parse_args()
|
||||
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
elif action == "disable":
|
||||
|
|
@ -126,7 +126,7 @@ class AnnotationListApi(Resource):
|
|||
"page": page,
|
||||
}
|
||||
|
||||
@service_api_ns.expect(annotation_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_annotation")
|
||||
@service_api_ns.doc(description="Create a new annotation")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -139,14 +139,14 @@ class AnnotationListApi(Resource):
|
|||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new annotation."""
|
||||
args = annotation_create_parser.parse_args()
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
|
||||
return annotation, 201
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@service_api_ns.expect(annotation_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("update_annotation")
|
||||
@service_api_ns.doc(description="Update an existing annotation")
|
||||
@service_api_ns.doc(params={"annotation_id": "Annotation ID"})
|
||||
|
|
@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
args = annotation_create_parser.parse_args()
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
return annotation
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
|
|
@ -84,19 +86,19 @@ class AudioApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
# Define parser for text-to-audio API
|
||||
text_to_audio_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
|
||||
.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
|
||||
.add_argument("text", type=str, location="json", help="Text to convert to audio")
|
||||
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
|
||||
)
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
voice: str | None = Field(default=None, description="Voice to use for TTS")
|
||||
text: str | None = Field(default=None, description="Text to convert to audio")
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/text-to-audio")
|
||||
class TextApi(Resource):
|
||||
@service_api_ns.expect(text_to_audio_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
|
||||
@service_api_ns.doc("text_to_audio")
|
||||
@service_api_ns.doc(description="Convert text to audio using text-to-speech")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -114,11 +116,11 @@ class TextApi(Resource):
|
|||
Converts the provided text to audio using the specified voice.
|
||||
"""
|
||||
try:
|
||||
args = text_to_audio_parser.parse_args()
|
||||
payload = TextToAudioPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = payload.message_id
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
import logging
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
|
|
@ -26,7 +30,6 @@ from core.errors.error import (
|
|||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
|
|
@ -36,40 +39,43 @@ from services.errors.llm import InvokeRateLimitError
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parser for completion API
|
||||
completion_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
|
||||
.add_argument("query", type=str, location="json", default="", help="The query string")
|
||||
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
|
||||
)
|
||||
class CompletionRequestPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = Field(default="")
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = Field(default="dev")
|
||||
|
||||
# Define parser for chat API
|
||||
chat_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
|
||||
.add_argument("query", type=str, required=True, location="json", help="The chat query")
|
||||
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
|
||||
.add_argument(
|
||||
"auto_generate_name",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=True,
|
||||
location="json",
|
||||
help="Auto generate conversation name",
|
||||
)
|
||||
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
|
||||
)
|
||||
|
||||
class ChatRequestPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: str | None = Field(default=None, description="Conversation UUID")
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
|
||||
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
try:
|
||||
return helper.uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/completion-messages")
|
||||
class CompletionApi(Resource):
|
||||
@service_api_ns.expect(completion_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
|
||||
@service_api_ns.doc("create_completion")
|
||||
@service_api_ns.doc(description="Create a completion for the given prompt")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -91,12 +97,13 @@ class CompletionApi(Resource):
|
|||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
args = completion_parser.parse_args()
|
||||
payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
|
|
@ -162,7 +169,7 @@ class CompletionStopApi(Resource):
|
|||
|
||||
@service_api_ns.route("/chat-messages")
|
||||
class ChatApi(Resource):
|
||||
@service_api_ns.expect(chat_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
|
||||
@service_api_ns.doc("create_chat_message")
|
||||
@service_api_ns.doc(description="Send a message in a chat conversation")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -186,13 +193,14 @@ class ChatApi(Resource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = chat_parser.parse_args()
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
|
|
|||
|
|
@ -1,10 +1,15 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
|
|
@ -19,74 +24,51 @@ from fields.conversation_variable_fields import (
|
|||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
# Define parsers for conversation APIs
|
||||
conversation_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of conversations to return",
|
||||
)
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
help="Sort order for conversations",
|
||||
)
|
||||
)
|
||||
|
||||
conversation_rename_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, location="json", help="New conversation name")
|
||||
.add_argument(
|
||||
"auto_generate",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=False,
|
||||
location="json",
|
||||
help="Auto-generate conversation name",
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort order for conversations"
|
||||
)
|
||||
)
|
||||
|
||||
conversation_variables_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of variables to return",
|
||||
)
|
||||
)
|
||||
|
||||
conversation_variable_update_parser = reqparse.RequestParser().add_argument(
|
||||
# using lambda is for passing the already-typed value without modification
|
||||
# if no lambda, it will be converted to string
|
||||
# the string cannot be converted using json.loads
|
||||
"value",
|
||||
required=True,
|
||||
location="json",
|
||||
type=lambda x: x,
|
||||
help="New value for the conversation variable",
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
|
||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
value: Any
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
ConversationListQuery,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations")
|
||||
class ConversationApi(Resource):
|
||||
@service_api_ns.expect(conversation_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__])
|
||||
@service_api_ns.doc("list_conversations")
|
||||
@service_api_ns.doc(description="List all conversations for the current user")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -107,7 +89,8 @@ class ConversationApi(Resource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = conversation_list_parser.parse_args()
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -115,10 +98,10 @@ class ConversationApi(Resource):
|
|||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
last_id=args["last_id"],
|
||||
limit=args["limit"],
|
||||
last_id=last_id,
|
||||
limit=query_args.limit,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
sort_by=args["sort_by"],
|
||||
sort_by=query_args.sort_by,
|
||||
)
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
|
@ -155,7 +138,7 @@ class ConversationDetailApi(Resource):
|
|||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
|
||||
class ConversationRenameApi(Resource):
|
||||
@service_api_ns.expect(conversation_rename_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
|
||||
@service_api_ns.doc("rename_conversation")
|
||||
@service_api_ns.doc(description="Rename a conversation or auto-generate a name")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID"})
|
||||
|
|
@ -176,17 +159,17 @@ class ConversationRenameApi(Resource):
|
|||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
args = conversation_rename_parser.parse_args()
|
||||
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables")
|
||||
class ConversationVariablesApi(Resource):
|
||||
@service_api_ns.expect(conversation_variables_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__])
|
||||
@service_api_ns.doc("list_conversation_variables")
|
||||
@service_api_ns.doc(description="List all variables for a conversation")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID"})
|
||||
|
|
@ -211,11 +194,12 @@ class ConversationVariablesApi(Resource):
|
|||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
args = conversation_variables_parser.parse_args()
|
||||
query_args = ConversationVariablesQuery.model_validate(request.args.to_dict())
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
return ConversationService.get_conversational_variable(
|
||||
app_model, conversation_id, end_user, args["limit"], args["last_id"]
|
||||
app_model, conversation_id, end_user, query_args.limit, last_id
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
@ -223,7 +207,7 @@ class ConversationVariablesApi(Resource):
|
|||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
|
||||
class ConversationVariableDetailApi(Resource):
|
||||
@service_api_ns.expect(conversation_variable_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_conversation_variable")
|
||||
@service_api_ns.doc(description="Update a conversation variable's value")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
|
||||
|
|
@ -250,11 +234,11 @@ class ConversationVariableDetailApi(Resource):
|
|||
conversation_id = str(c_id)
|
||||
variable_id = str(variable_id)
|
||||
|
||||
args = conversation_variable_update_parser.parse_args()
|
||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, args["value"]
|
||||
app_model, conversation_id, variable_id, end_user, payload.value
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import logging
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
FileAccessDeniedError,
|
||||
|
|
@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parser for file preview API
|
||||
file_preview_parser = reqparse.RequestParser().add_argument(
|
||||
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
|
||||
)
|
||||
class FilePreviewQuery(BaseModel):
|
||||
as_attachment: bool = Field(default=False, description="Download as attachment")
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, FilePreviewQuery)
|
||||
|
||||
|
||||
@service_api_ns.route("/files/<uuid:file_id>/preview")
|
||||
|
|
@ -32,7 +35,7 @@ class FilePreviewApi(Resource):
|
|||
Files can only be accessed if they belong to messages within the requesting app's context.
|
||||
"""
|
||||
|
||||
@service_api_ns.expect(file_preview_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__])
|
||||
@service_api_ns.doc("preview_file")
|
||||
@service_api_ns.doc(description="Preview or download a file uploaded via Service API")
|
||||
@service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
|
||||
|
|
@ -55,7 +58,7 @@ class FilePreviewApi(Resource):
|
|||
file_id = str(file_id)
|
||||
|
||||
# Parse query parameters
|
||||
args = file_preview_parser.parse_args()
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
|
|
@ -67,7 +70,7 @@ class FilePreviewApi(Resource):
|
|||
raise FileNotFoundError(f"Failed to load file content: {str(e)}")
|
||||
|
||||
# Build response with appropriate headers
|
||||
response = self._build_file_response(generator, upload_file, args["as_attachment"])
|
||||
response = self._build_file_response(generator, upload_file, args.as_attachment)
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
|
|
@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from fields.conversation_fields import build_message_file_model
|
||||
from fields.message_fields import build_agent_thought_model, build_feedback_model
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
|
|
@ -25,42 +29,26 @@ from services.message_service import MessageService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parsers for message APIs
|
||||
message_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of messages to return",
|
||||
)
|
||||
)
|
||||
|
||||
message_feedback_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
|
||||
.add_argument("content", type=str, location="json", help="Feedback content")
|
||||
)
|
||||
|
||||
feedback_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, default=1, location="args", help="Page number")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 101),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of feedbacks per page",
|
||||
)
|
||||
)
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Api | Namespace):
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class FeedbackListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Namespace):
|
||||
"""Build the message model for the API or Namespace."""
|
||||
# First build the nested models
|
||||
feedback_model = build_feedback_model(api_or_ns)
|
||||
|
|
@ -90,7 +78,7 @@ def build_message_model(api_or_ns: Api | Namespace):
|
|||
return api_or_ns.model("Message", message_fields)
|
||||
|
||||
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the message infinite scroll pagination model for the API or Namespace."""
|
||||
# Build the nested message model first
|
||||
message_model = build_message_model(api_or_ns)
|
||||
|
|
@ -105,7 +93,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
|||
|
||||
@service_api_ns.route("/messages")
|
||||
class MessageListApi(Resource):
|
||||
@service_api_ns.expect(message_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
|
||||
@service_api_ns.doc("list_messages")
|
||||
@service_api_ns.doc(description="List messages in a conversation")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -126,11 +114,13 @@ class MessageListApi(Resource):
|
|||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = message_list_parser.parse_args()
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
conversation_id = str(query_args.conversation_id)
|
||||
first_id = str(query_args.first_id) if query_args.first_id else None
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, end_user, conversation_id, first_id, query_args.limit
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
|
@ -140,7 +130,7 @@ class MessageListApi(Resource):
|
|||
|
||||
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
|
||||
class MessageFeedbackApi(Resource):
|
||||
@service_api_ns.expect(message_feedback_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
|
||||
@service_api_ns.doc("create_message_feedback")
|
||||
@service_api_ns.doc(description="Submit feedback for a message")
|
||||
@service_api_ns.doc(params={"message_id": "Message ID"})
|
||||
|
|
@ -159,15 +149,15 @@ class MessageFeedbackApi(Resource):
|
|||
"""
|
||||
message_id = str(message_id)
|
||||
|
||||
args = message_feedback_parser.parse_args()
|
||||
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=args.get("rating"),
|
||||
content=args.get("content"),
|
||||
rating=payload.rating,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
|
@ -177,7 +167,7 @@ class MessageFeedbackApi(Resource):
|
|||
|
||||
@service_api_ns.route("/app/feedbacks")
|
||||
class AppGetFeedbacksApi(Resource):
|
||||
@service_api_ns.expect(feedback_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__])
|
||||
@service_api_ns.doc("get_app_feedbacks")
|
||||
@service_api_ns.doc(description="Get all feedbacks for the application")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -192,8 +182,8 @@ class AppGetFeedbacksApi(Resource):
|
|||
|
||||
Returns paginated list of all feedback submitted for messages in this app.
|
||||
"""
|
||||
args = feedback_list_parser.parse_args()
|
||||
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
|
||||
query_args = FeedbackListQuery.model_validate(request.args.to_dict())
|
||||
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit)
|
||||
return {"data": feedbacks}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
CompletionRequestError,
|
||||
|
|
@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define parsers for workflow APIs
|
||||
workflow_run_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
)
|
||||
|
||||
workflow_log_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
.add_argument("created_at__before", type=str, location="args")
|
||||
.add_argument("created_at__after", type=str, location="args")
|
||||
.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
class WorkflowLogQuery(BaseModel):
|
||||
keyword: str | None = None
|
||||
status: Literal["succeeded", "failed", "stopped"] | None = None
|
||||
created_at__before: str | None = None
|
||||
created_at__after: str | None = None
|
||||
created_by_end_user_session_id: str | None = None
|
||||
created_by_account: str | None = None
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
|
|
@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource):
|
|||
|
||||
@service_api_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(Resource):
|
||||
@service_api_ns.expect(workflow_run_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc("run_workflow")
|
||||
@service_api_ns.doc(description="Execute a workflow")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -154,11 +144,12 @@ class WorkflowRunApi(Resource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
args = workflow_run_parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
|
@ -185,7 +176,7 @@ class WorkflowRunApi(Resource):
|
|||
|
||||
@service_api_ns.route("/workflows/<string:workflow_id>/run")
|
||||
class WorkflowRunByIdApi(Resource):
|
||||
@service_api_ns.expect(workflow_run_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc("run_workflow_by_id")
|
||||
@service_api_ns.doc(description="Execute a specific workflow by ID")
|
||||
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
|
||||
|
|
@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
args = workflow_run_parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
# Add workflow_id to args for AppGenerateService
|
||||
args["workflow_id"] = workflow_id
|
||||
|
|
@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource):
|
|||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
|
@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource):
|
|||
|
||||
@service_api_ns.route("/workflows/logs")
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@service_api_ns.expect(workflow_log_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__])
|
||||
@service_api_ns.doc("get_workflow_logs")
|
||||
@service_api_ns.doc(description="Get workflow execution logs")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource):
|
|||
|
||||
Returns paginated workflow execution logs with filtering options.
|
||||
"""
|
||||
args = workflow_log_parser.parse_args()
|
||||
args = WorkflowLogQuery.model_validate(request.args.to_dict())
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
|
||||
created_at_after = isoparse(args.created_at__after) if args.created_at__after else None
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
|
|
@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource):
|
|||
session=session,
|
||||
app_model=app_model,
|
||||
keyword=args.keyword,
|
||||
status=args.status,
|
||||
created_at_before=args.created_at__before,
|
||||
created_at_after=args.created_at__after,
|
||||
status=status,
|
||||
created_at_before=created_at_before,
|
||||
created_at_after=created_at_after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
|
|
@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager
|
|||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from libs.validators import validate_description_length
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
class DatasetCreatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400)
|
||||
indexing_technique: Literal["high_quality", "economy"] | None = None
|
||||
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
|
||||
external_knowledge_api_id: str | None = None
|
||||
provider: str = "vendor"
|
||||
external_knowledge_id: str | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
|
||||
|
||||
# Define parsers for dataset operations
|
||||
dataset_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
help="Invalid permission.",
|
||||
required=False,
|
||||
nullable=False,
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="_validate_name",
|
||||
)
|
||||
.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
class DatasetUpdatePayload(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=40)
|
||||
description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400)
|
||||
indexing_technique: Literal["high_quality", "economy"] | None = None
|
||||
permission: DatasetPermissionEnum | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
partial_member_list: list[str] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
external_knowledge_api_id: str | None = None
|
||||
|
||||
dataset_update_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
help="Invalid permission.",
|
||||
)
|
||||
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||
.add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
|
||||
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||
.add_argument(
|
||||
"external_retrieval_model",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external retrieval model.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge id.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
)
|
||||
|
||||
tag_create_parser = reqparse.RequestParser().add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=lambda x: x
|
||||
if x and 1 <= len(x) <= 50
|
||||
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
|
||||
)
|
||||
class TagNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=50)
|
||||
|
||||
tag_update_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=lambda x: x
|
||||
if x and 1 <= len(x) <= 50
|
||||
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
|
||||
)
|
||||
.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
)
|
||||
|
||||
tag_delete_parser = reqparse.RequestParser().add_argument(
|
||||
"tag_id", nullable=False, required=True, help="Id of a tag.", type=str
|
||||
)
|
||||
class TagCreatePayload(TagNamePayload):
|
||||
pass
|
||||
|
||||
tag_binding_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
|
||||
.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
|
||||
)
|
||||
)
|
||||
|
||||
tag_unbinding_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
class TagUpdatePayload(TagNamePayload):
|
||||
tag_id: str
|
||||
|
||||
|
||||
class TagDeletePayload(BaseModel):
|
||||
tag_id: str
|
||||
|
||||
|
||||
class TagBindingPayload(BaseModel):
|
||||
tag_ids: list[str]
|
||||
target_id: str
|
||||
|
||||
@field_validator("tag_ids")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: list[str]) -> list[str]:
|
||||
if not value:
|
||||
raise ValueError("Tag IDs is required.")
|
||||
return value
|
||||
|
||||
|
||||
class TagUnbindingPayload(BaseModel):
|
||||
tag_id: str
|
||||
target_id: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
DatasetCreatePayload,
|
||||
DatasetUpdatePayload,
|
||||
TagCreatePayload,
|
||||
TagUpdatePayload,
|
||||
TagDeletePayload,
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource):
|
|||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(dataset_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset")
|
||||
@service_api_ns.doc(description="Create a new dataset")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id):
|
||||
"""Resource for creating datasets."""
|
||||
args = dataset_create_parser.parse_args()
|
||||
payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
try:
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
indexing_technique=payload.indexing_technique,
|
||||
account=current_user,
|
||||
permission=args["permission"],
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
external_knowledge_id=args["external_knowledge_id"],
|
||||
embedding_model_provider=args["embedding_model_provider"],
|
||||
embedding_model_name=args["embedding_model"],
|
||||
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
|
||||
if args["retrieval_model"] is not None
|
||||
else None,
|
||||
permission=str(payload.permission) if payload.permission else None,
|
||||
provider=payload.provider,
|
||||
external_knowledge_api_id=payload.external_knowledge_api_id,
|
||||
external_knowledge_id=payload.external_knowledge_id,
|
||||
embedding_model_provider=payload.embedding_model_provider,
|
||||
embedding_model_name=payload.embedding_model,
|
||||
retrieval_model=payload.retrieval_model,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
|
@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource):
|
|||
|
||||
return data, 200
|
||||
|
||||
@service_api_ns.expect(dataset_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset")
|
||||
@service_api_ns.doc(description="Update an existing dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
|
|
@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource):
|
|||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
args = dataset_update_parser.parse_args()
|
||||
data = request.get_json()
|
||||
payload_dict = service_api_ns.payload or {}
|
||||
payload = DatasetUpdatePayload.model_validate(payload_dict)
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if payload.permission is not None:
|
||||
update_data["permission"] = str(payload.permission)
|
||||
if payload.retrieval_model is not None:
|
||||
update_data["retrieval_model"] = payload.retrieval_model.model_dump()
|
||||
|
||||
# check embedding model setting
|
||||
embedding_model_provider = data.get("embedding_model_provider")
|
||||
embedding_model = data.get("embedding_model")
|
||||
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if payload.indexing_technique == "high_quality" or embedding_model_provider:
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||
)
|
||||
|
||||
retrieval_model = data.get("retrieval_model")
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
current_user, dataset, data.get("permission"), data.get("partial_member_list")
|
||||
current_user,
|
||||
dataset,
|
||||
str(payload.permission) if payload.permission else None,
|
||||
payload.partial_member_list,
|
||||
)
|
||||
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user)
|
||||
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
|
@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource):
|
|||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant_id, dataset_id_str, data.get("partial_member_list")
|
||||
)
|
||||
if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
|
||||
# clear partial member list when permission is only_me or all_team_members
|
||||
elif (
|
||||
data.get("permission") == DatasetPermissionEnum.ONLY_ME
|
||||
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
|
||||
):
|
||||
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
|
@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
|
||||
return tags, 200
|
||||
|
||||
@service_api_ns.expect(tag_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@service_api_ns.doc(description="Add a knowledge type tag")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_create_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.save_tags(args)
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(tag_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset_tag")
|
||||
@service_api_ns.doc(description="Update a knowledge type tag")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_update_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag_id = args["tag_id"]
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
params = {"name": payload.name, "type": "knowledge"}
|
||||
tag_id = payload.tag_id
|
||||
tag = TagService.update_tags(params, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
|
|
@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(tag_delete_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
@service_api_ns.doc("delete_dataset_tag")
|
||||
@service_api_ns.doc(description="Delete a knowledge type tag")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
@edit_permission_required
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args["tag_id"])
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/binding")
|
||||
class DatasetTagBindingApi(DatasetApiResource):
|
||||
@service_api_ns.expect(tag_binding_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
|
||||
@service_api_ns.doc("bind_dataset_tags")
|
||||
@service_api_ns.doc(description="Bind tags to a dataset")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_binding_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.save_tag_binding(args)
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@service_api_ns.expect(tag_unbinding_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
|
||||
@service_api_ns.doc("unbind_dataset_tag")
|
||||
@service_api_ns.doc(description="Unbind a tag from a dataset")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_unbinding_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.delete_tag_binding(args)
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return 204
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ from typing import Self
|
|||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from pydantic import BaseModel, model_validator
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlalchemy import desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
|
@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService
|
|||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
|
||||
# Define parsers for document operations
|
||||
document_text_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("text", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
class DocumentTextCreatePayload(BaseModel):
|
||||
name: str
|
||||
text: str
|
||||
process_rule: ProcessRule | None = None
|
||||
original_document_id: str | None = None
|
||||
doc_form: str = Field(default="text_model")
|
||||
doc_language: str = Field(default="English")
|
||||
indexing_technique: str | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
|
@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel):
|
|||
return self
|
||||
|
||||
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
|
||||
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
|
||||
|
||||
|
||||
|
|
@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
|
|||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
|
||||
@service_api_ns.expect(document_text_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text")
|
||||
@service_api_ns.doc(description="Create a new document by providing text content")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
|
|
@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by text."""
|
||||
args = document_text_create_parser.parse_args()
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
|
@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
|
|
@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text")
|
||||
@service_api_ns.doc(description="Update an existing document by providing text content")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
|
@ -189,22 +183,23 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text."""
|
||||
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
|
||||
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
|
|
@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
|||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
# Define parsers for metadata APIs
|
||||
metadata_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name")
|
||||
)
|
||||
|
||||
metadata_update_parser = reqparse.RequestParser().add_argument(
|
||||
"name", type=str, required=True, nullable=False, location="json", help="New metadata name"
|
||||
)
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
document_metadata_parser = reqparse.RequestParser().add_argument(
|
||||
"operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data"
|
||||
)
|
||||
|
||||
register_schema_model(service_api_ns, MetadataUpdatePayload)
|
||||
register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
@service_api_ns.expect(metadata_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
|
||||
@service_api_ns.doc("create_dataset_metadata")
|
||||
@service_api_ns.doc(description="Create metadata for a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
|
|
@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create metadata for a dataset."""
|
||||
args = metadata_create_parser.parse_args()
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
|
@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
|||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
@service_api_ns.expect(metadata_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset_metadata")
|
||||
@service_api_ns.doc(description="Update metadata name")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
|
||||
|
|
@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Update metadata name."""
|
||||
args = metadata_update_parser.parse_args()
|
||||
payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
|
|
@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
|||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
|
|
@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
|||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
@service_api_ns.expect(document_metadata_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
|
||||
@service_api_ns.doc("update_documents_metadata")
|
||||
@service_api_ns.doc(description="Update metadata for multiple documents")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
|
|
@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
|||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
args = document_metadata_parser.parse_args()
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ from collections.abc import Generator
|
|||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.reqparse import ParseResult, RequestParser
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
|
|
@ -22,11 +22,25 @@ from models.dataset import Pipeline
|
|||
from models.engine import db
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
|
||||
from services.rag_pipeline.entity.pipeline_service_api_entities import (
|
||||
DatasourceNodeRunApiEntity,
|
||||
PipelineRunApiEntity,
|
||||
)
|
||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
class DatasourceNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
datasource_type: str
|
||||
credential_id: str | None = None
|
||||
is_published: bool
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, DatasourceNodeRunPayload)
|
||||
register_schema_model(service_api_ns, PipelineRunApiEntity)
|
||||
|
||||
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
|
||||
class DatasourcePluginsApi(DatasetApiResource):
|
||||
"""Resource for datasource plugins."""
|
||||
|
|
@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
|||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Get query parameter to determine published or draft
|
||||
parser: RequestParser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
.add_argument("is_published", type=bool, required=True, location="json")
|
||||
)
|
||||
args: ParseResult = parser.parse_args()
|
||||
|
||||
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
|
||||
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
|
||||
{
|
||||
**payload.model_dump(exclude_none=True),
|
||||
"pipeline_id": str(pipeline.id),
|
||||
"node_id": node_id,
|
||||
}
|
||||
)
|
||||
return helper.compact_generate_response(
|
||||
PipelineGenerator.convert_to_event_stream(
|
||||
rag_pipeline_service.run_datasource_workflow_node(
|
||||
|
|
@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource):
|
|||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for running a rag pipeline."""
|
||||
parser: RequestParser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("is_published", type=bool, required=True, default=True, location="json")
|
||||
.add_argument(
|
||||
"response_mode",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["streaming", "blocking"],
|
||||
default="blocking",
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
args: ParseResult = parser.parse_args()
|
||||
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
|
@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource):
|
|||
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
|
||||
streaming=args.get("response_mode") == "streaming",
|
||||
args=payload.model_dump(),
|
||||
invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
|
||||
streaming=payload.response_mode == "streaming",
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import (
|
||||
|
|
@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing
|
|||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
|
||||
# Define parsers for segment operations
|
||||
segment_create_parser = reqparse.RequestParser().add_argument(
|
||||
"segments", type=list, required=False, nullable=True, location="json"
|
||||
)
|
||||
|
||||
segment_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
)
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
segments: list[dict[str, Any]] | None = None
|
||||
|
||||
segment_update_parser = reqparse.RequestParser().add_argument(
|
||||
"segment", type=dict, required=False, nullable=True, location="json"
|
||||
)
|
||||
|
||||
child_chunk_create_parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
class SegmentListQuery(BaseModel):
|
||||
status: list[str] = Field(default_factory=list)
|
||||
keyword: str | None = None
|
||||
|
||||
child_chunk_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
|
||||
child_chunk_update_parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
class SegmentUpdatePayload(BaseModel):
|
||||
segment: SegmentUpdateArgs
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1)
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
SegmentCreatePayload,
|
||||
SegmentListQuery,
|
||||
SegmentUpdatePayload,
|
||||
ChildChunkCreatePayload,
|
||||
ChildChunkListQuery,
|
||||
ChildChunkUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument(
|
|||
class SegmentApi(DatasetApiResource):
|
||||
"""Resource for segments."""
|
||||
|
||||
@service_api_ns.expect(segment_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_segments")
|
||||
@service_api_ns.doc(description="Create segments in a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
|
@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource):
|
|||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
args = segment_create_parser.parse_args()
|
||||
if args["segments"] is not None:
|
||||
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
if payload.segments is not None:
|
||||
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
|
||||
if segments_limit > 0 and len(args["segments"]) > segments_limit:
|
||||
if segments_limit > 0 and len(payload.segments) > segments_limit:
|
||||
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
|
||||
|
||||
for args_item in args["segments"]:
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@service_api_ns.expect(segment_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
|
||||
@service_api_ns.doc("list_segments")
|
||||
@service_api_ns.doc(description="List segments in a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
|
@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource):
|
|||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
args = segment_list_parser.parse_args()
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
"status": request.args.getlist("status"),
|
||||
"keyword": request.args.get("keyword"),
|
||||
}
|
||||
)
|
||||
|
||||
segments, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=current_tenant_id,
|
||||
status_list=args["status"],
|
||||
keyword=args["keyword"],
|
||||
status_list=args.status,
|
||||
keyword=args.keyword,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
|
@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(segment_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_segment")
|
||||
@service_api_ns.doc(description="Update a specific segment")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate args
|
||||
args = segment_update_parser.parse_args()
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
|
||||
)
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
|
|
@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
class ChildChunkApi(DatasetApiResource):
|
||||
"""Resource for child chunks."""
|
||||
|
||||
@service_api_ns.expect(child_chunk_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_child_chunk")
|
||||
@service_api_ns.doc(description="Create a new child chunk for a segment")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource):
|
|||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
# validate args
|
||||
args = child_chunk_create_parser.parse_args()
|
||||
payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset)
|
||||
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
||||
@service_api_ns.expect(child_chunk_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
|
||||
@service_api_ns.doc("list_child_chunks")
|
||||
@service_api_ns.doc(description="List child chunks for a segment")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource):
|
|||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
args = child_chunk_list_parser.parse_args()
|
||||
args = ChildChunkListQuery.model_validate(
|
||||
{
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"keyword": request.args.get("keyword"),
|
||||
"page": request.args.get("page", default=1, type=int),
|
||||
}
|
||||
)
|
||||
|
||||
page = args["page"]
|
||||
limit = min(args["limit"], 100)
|
||||
keyword = args["keyword"]
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
keyword = args.keyword
|
||||
|
||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||
|
||||
|
|
@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(child_chunk_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_child_chunk")
|
||||
@service_api_ns.doc(description="Update a specific child chunk")
|
||||
@service_api_ns.doc(
|
||||
|
|
@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate args
|
||||
args = child_chunk_update_parser.parse_args()
|
||||
payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset)
|
||||
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
|
|
@ -28,6 +29,8 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WhereisUserArg(StrEnum):
|
||||
"""
|
||||
|
|
@ -238,8 +241,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
|||
# Basic check: UUIDs are 36 chars with hyphens
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from class method args")
|
||||
elif len(args) > 0:
|
||||
# Not a class method, check if args[0] looks like a UUID
|
||||
potential_id = args[0]
|
||||
|
|
@ -247,8 +250,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
|||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
|
|
@ -316,18 +319,16 @@ def validate_and_get_api_token(scope: str | None = None):
|
|||
ApiToken.type == scope,
|
||||
)
|
||||
.values(last_used_at=current_time)
|
||||
.returning(ApiToken)
|
||||
)
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
result = session.execute(update_stmt)
|
||||
api_token = result.scalar_one_or_none()
|
||||
api_token = session.scalar(stmt)
|
||||
|
||||
if hasattr(result, "rowcount") and result.rowcount > 0:
|
||||
session.commit()
|
||||
|
||||
if not api_token:
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
api_token = session.scalar(stmt)
|
||||
if not api_token:
|
||||
raise Unauthorized("Access token is invalid")
|
||||
else:
|
||||
session.commit()
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
return api_token
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
|
@ -23,6 +24,8 @@ from core.tools.entities.tool_entities import ToolInvokeMeta
|
|||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
|
|
@ -400,8 +403,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Sequence
|
|||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
from jsonschema import Draft7Validator, SchemaError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
|
|
@ -98,6 +99,7 @@ class VariableEntityType(StrEnum):
|
|||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
JSON_OBJECT = "json_object"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
|
|
@ -118,6 +120,7 @@ class VariableEntity(BaseModel):
|
|||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -129,6 +132,17 @@ class VariableEntity(BaseModel):
|
|||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||
return v or []
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from core.workflow.variable_loader import VariableLoader
|
|||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.otel import WorkflowAppRunnerHandler, trace_span
|
||||
from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
|
|
@ -80,6 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
|
|
|||
|
|
@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
|||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
|
|
@ -73,7 +72,7 @@ from extensions.ext_database import db
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif event.stopped_by in (
|
||||
|
|
@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
# When hitting input-moderation or annotation-reply, the workflow will not start
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
self._save_message(session=session)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
|
|
@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
event: QueueAdvancedChatMessageEndEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle advanced chat message end events."""
|
||||
|
|
@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
|
|
@ -770,15 +768,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
tts_publisher.publish(None)
|
||||
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
):
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
message = self._get_message(session=session)
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
|
|
@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
metadata = self._task_state.metadata.model_dump()
|
||||
message.message_metadata = json.dumps(jsonable_encoder(metadata))
|
||||
|
||||
# Extract model provider and model_id from workflow node executions for tracing
|
||||
if message.workflow_run_id:
|
||||
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
|
||||
if model_info:
|
||||
message.model_provider = model_info.get("provider")
|
||||
message.model_id = model_info.get("model")
|
||||
|
||||
message_files = [
|
||||
MessageFile(
|
||||
message_id=message.id,
|
||||
|
|
@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
# Trigger MESSAGE_TRACE for tracing integrations
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
|
||||
"""
|
||||
Extract model provider and model_id from workflow node executions.
|
||||
Returns dict with 'provider' and 'model' keys, or None if not found.
|
||||
"""
|
||||
try:
|
||||
# Query workflow node executions for LLM or Agent nodes
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
|
||||
.order_by(WorkflowNodeExecutionModel.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
node_execution = session.scalar(stmt)
|
||||
|
||||
if not node_execution:
|
||||
return None
|
||||
|
||||
# Try to extract from execution_metadata for agent nodes
|
||||
if node_execution.execution_metadata:
|
||||
try:
|
||||
metadata = json.loads(node_execution.execution_metadata)
|
||||
agent_log = metadata.get("agent_log", [])
|
||||
# Look for the first agent thought with provider info
|
||||
for log_entry in agent_log:
|
||||
entry_metadata = log_entry.get("metadata", {})
|
||||
provider_str = entry_metadata.get("provider")
|
||||
if provider_str:
|
||||
# Parse format like "langgenius/deepseek/deepseek"
|
||||
parts = provider_str.split("/")
|
||||
if len(parts) >= 3:
|
||||
return {"provider": parts[1], "model": parts[2]}
|
||||
elif len(parts) == 2:
|
||||
return {"provider": parts[0], "model": parts[1]}
|
||||
except (json.JSONDecodeError, KeyError, AttributeError) as e:
|
||||
logger.debug("Failed to parse execution_metadata: %s", e)
|
||||
|
||||
# Try to extract from process_data for llm nodes
|
||||
if node_execution.process_data:
|
||||
try:
|
||||
process_data = json.loads(node_execution.process_data)
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider and model:
|
||||
return {"provider": provider, "model": model}
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.debug("Failed to parse process_data: %s", e)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Failed to extract model info from workflow: %s", e)
|
||||
return None
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
|
|
|||
|
|
@ -99,6 +99,15 @@ class BaseAppGenerator:
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
# Treat empty placeholders for optional file inputs as unset
|
||||
if (
|
||||
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
|
||||
and not variable_entity.required
|
||||
):
|
||||
# Treat empty string (frontend default) or empty list as unset
|
||||
if not value and isinstance(value, (str, list)):
|
||||
return None
|
||||
|
||||
if variable_entity.type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ class AppRunner:
|
|||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
context_files: list["File"] | None = None,
|
||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
|
|
@ -111,6 +112,7 @@ class AppRunner:
|
|||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
context_files=context_files,
|
||||
)
|
||||
else:
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.file import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
|
@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner):
|
|||
|
||||
# get context from datasets
|
||||
context = None
|
||||
context_files: list[File] = []
|
||||
if app_config.dataset and app_config.dataset.dataset_ids:
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager,
|
||||
|
|
@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner):
|
|||
)
|
||||
|
||||
dataset_retrieval = DatasetRetrieval(application_generate_entity)
|
||||
context = dataset_retrieval.retrieve(
|
||||
context, retrieved_files = dataset_retrieval.retrieve(
|
||||
app_id=app_record.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_record.tenant_id,
|
||||
|
|
@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner):
|
|||
memory=memory,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||
"enabled", False
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
|
|
@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner):
|
|||
context=context,
|
||||
memory=memory,
|
||||
image_detail_config=image_detail_config,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -55,6 +56,7 @@ from models import Account, EndUser
|
|||
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||
|
||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
|
@ -289,26 +291,30 @@ class WorkflowResponseConverter:
|
|||
),
|
||||
)
|
||||
|
||||
if event.node_type == NodeType.TOOL:
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=ToolProviderType(event.provider_type),
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
elif event.node_type == NodeType.DATASOURCE:
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
)
|
||||
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
|
||||
self._application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
elif event.node_type == NodeType.TRIGGER_PLUGIN:
|
||||
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
)
|
||||
try:
|
||||
if event.node_type == NodeType.TOOL:
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=ToolProviderType(event.provider_type),
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
elif event.node_type == NodeType.DATASOURCE:
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
)
|
||||
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
|
||||
self._application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
elif event.node_type == NodeType.TRIGGER_PLUGIN:
|
||||
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
)
|
||||
except Exception:
|
||||
# metadata fetch may fail, for example, the plugin daemon is down or plugin is uninstalled.
|
||||
logger.warning("failed to fetch icon for %s", event.provider_id)
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.moderation.base import ModerationError
|
||||
|
|
@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner):
|
|||
|
||||
# get context from datasets
|
||||
context = None
|
||||
context_files: list[File] = []
|
||||
if app_config.dataset and app_config.dataset.dataset_ids:
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager,
|
||||
|
|
@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner):
|
|||
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
||||
|
||||
dataset_retrieval = DatasetRetrieval(application_generate_entity)
|
||||
context = dataset_retrieval.retrieve(
|
||||
context, retrieved_files = dataset_retrieval.retrieve(
|
||||
app_id=app_record.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_record.tenant_id,
|
||||
|
|
@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner):
|
|||
hit_callback=hit_callback,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||
"enabled", False
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
|
|
@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner):
|
|||
query=query,
|
||||
context=context,
|
||||
image_detail_config=image_detail_config,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
|
|
|
|||
|
|
@ -156,79 +156,86 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
query = application_generate_entity.query or "New conversation"
|
||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
try:
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
app_id=app_config.app_id,
|
||||
app_model_config_id=app_model_config_id,
|
||||
model_provider=model_provider,
|
||||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=app_config.app_mode.value,
|
||||
name=conversation_name,
|
||||
inputs=application_generate_entity.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.flush()
|
||||
db.session.refresh(conversation)
|
||||
else:
|
||||
conversation.updated_at = naive_utc_now()
|
||||
|
||||
message = Message(
|
||||
app_id=app_config.app_id,
|
||||
app_model_config_id=app_model_config_id,
|
||||
model_provider=model_provider,
|
||||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=app_config.app_mode.value,
|
||||
name=conversation_name,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
query=application_generate_entity.query,
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
app_mode=app_config.app_mode,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.add(message)
|
||||
db.session.flush()
|
||||
db.session.refresh(message)
|
||||
|
||||
message_files = []
|
||||
for file in application_generate_entity.files:
|
||||
message_file = MessageFile(
|
||||
message_id=message.id,
|
||||
type=file.type,
|
||||
transfer_method=file.transfer_method,
|
||||
belongs_to="user",
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||
created_by=account_id or end_user_id or "",
|
||||
)
|
||||
message_files.append(message_file)
|
||||
|
||||
if message_files:
|
||||
db.session.add_all(message_files)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
else:
|
||||
conversation.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
message = Message(
|
||||
app_id=app_config.app_id,
|
||||
model_provider=model_provider,
|
||||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query,
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
app_mode=app_config.app_mode,
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db.session.refresh(message)
|
||||
|
||||
for file in application_generate_entity.files:
|
||||
message_file = MessageFile(
|
||||
message_id=message.id,
|
||||
type=file.type,
|
||||
transfer_method=file.transfer_method,
|
||||
belongs_to="user",
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||
created_by=account_id or end_user_id or "",
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
|
||||
return conversation, message
|
||||
return conversation, message
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
raise
|
||||
|
||||
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from core.workflow.system_variable import SystemVariable
|
|||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.otel import WorkflowAppRunnerHandler, trace_span
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
|
|
@ -56,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
"""
|
||||
Run application
|
||||
|
|
|
|||
|
|
@ -40,9 +40,6 @@ class EasyUITaskState(TaskState):
|
|||
"""
|
||||
|
||||
llm_result: LLMResult
|
||||
first_token_time: float | None = None
|
||||
last_token_time: float | None = None
|
||||
is_streaming_response: bool = False
|
||||
|
||||
|
||||
class WorkflowTaskState(TaskState):
|
||||
|
|
|
|||
|
|
@ -332,12 +332,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
if not self._task_state.llm_result.prompt_messages:
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
|
||||
# Track streaming response times
|
||||
if self._task_state.first_token_time is None:
|
||||
self._task_state.first_token_time = time.perf_counter()
|
||||
self._task_state.is_streaming_response = True
|
||||
self._task_state.last_token_time = time.perf_counter()
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
|
||||
if should_direct_answer:
|
||||
|
|
@ -366,7 +360,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
|
||||
"""
|
||||
|
|
@ -404,18 +398,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.llm_result.usage.latency = message.provider_response_latency
|
||||
|
||||
# Add streaming metrics to usage if available
|
||||
if self._task_state.is_streaming_response and self._task_state.first_token_time:
|
||||
start_time = self.start_at
|
||||
first_token_time = self._task_state.first_token_time
|
||||
last_token_time = self._task_state.last_token_time or first_token_time
|
||||
usage.time_to_first_token = round(first_token_time - start_time, 3)
|
||||
usage.time_to_generate = round(last_token_time - first_token_time, 3)
|
||||
|
||||
# Update metadata with the complete usage info
|
||||
self._task_state.metadata.usage = usage
|
||||
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue