Merge branch 'main' into feat-agent-mask

This commit is contained in:
GuanMu 2025-11-27 10:43:48 +08:00 committed by GitHub
commit c0916e6eb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
440 changed files with 26882 additions and 2089 deletions

6
.cursorrules Normal file
View File

@ -0,0 +1,6 @@
# Cursor Rules for Dify Project
## Automated Test Generation
- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
- When proposing or saving tests, re-read that document and follow every requirement.

12
.github/copilot-instructions.md vendored Normal file
View File

@ -0,0 +1,12 @@
# Copilot Instructions
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
Key reminders:
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
- Target >95% line and branch coverage and 100% function/statement coverage.
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.

View File

@ -0,0 +1,5 @@
# Windsurf Testing Rules
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
- Honor every requirement in that document when generating or accepting tests.
- When proposing or saving tests, re-read that document and follow every requirement.

View File

@ -77,6 +77,8 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
#### Backend
For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly.

View File

@ -36,6 +36,12 @@
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
</p>
<p align="center">

View File

@ -176,6 +176,7 @@ WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1

View File

@ -16,6 +16,7 @@ layers =
graph
nodes
node_events
runtime
entities
containers =
core.workflow

View File

@ -57,7 +57,7 @@ RUN \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
expat libldap-2.5-0=2.5.13+dfsg-5 perl libsqlite3-0=3.40.1-2+deb12u2 zlib1g=1:1.2.13.dfsg-1 \
# install fonts to support the use of tools like pypdfium2
fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension
@ -73,7 +73,8 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
@ -86,7 +87,15 @@ COPY . /app/api/
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Create non-root user and set permissions
RUN groupadd -r -g 1001 dify && \
useradd -r -u 1001 -g 1001 -s /bin/bash dify && \
mkdir -p /home/dify && \
chown -R 1001:1001 /app /home/dify ${TIKTOKEN_CACHE_DIR} /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
ENV NLTK_DATA=/usr/local/share/nltk_data
USER 1001
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

View File

@ -31,3 +31,8 @@ class WeaviateConfig(BaseSettings):
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,
)
WEAVIATE_TOKENIZATION: str | None = Field(
description="Tokenization for Weaviate (default is word)",
default="word",
)

View File

@ -17,7 +17,6 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
@ -32,6 +31,7 @@ from libs.login import current_user, login_required
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
@ -121,7 +121,13 @@ class CompletionMessageStopApi(Resource):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200
@ -220,6 +226,12 @@ class ChatMessageStopApi(Resource):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200

View File

@ -369,6 +369,58 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
# Shared parser for feedback export (used for both documentation and runtime parsing)
feedback_export_parser = (
console_ns.parser()
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
)
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(feedback_export_parser)
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
args = feedback_export_parser.parse_args()
# Import the service function
from services.feedback_service import FeedbackService
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
from_source=args.get("from_source"),
rating=args.get("rating"),
has_comment=args.get("has_comment"),
start_date=args.get("start_date"),
end_date=args.get("end_date"),
format_type=args.get("format", "csv"),
)
return export_data
except ValueError as e:
logger.exception("Parameter validation error in feedback export")
return {"error": f"Parameter validation error: {str(e)}"}, 400
except Exception as e:
logger.exception("Error exporting feedback data")
raise InternalServerError(str(e))
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
class MessageApi(Resource):
@console_ns.doc("get_message")

View File

@ -90,14 +90,20 @@ workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagi
# Otherwise register it here
from fields.end_user_fields import simple_end_user_fields
simple_end_user_model = None
try:
simple_end_user_model = console_ns.models.get("SimpleEndUser")
except (KeyError, AttributeError):
except AttributeError:
pass
if simple_end_user_model is None:
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
workflow_run_node_execution_model = None
try:
workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
except (KeyError, AttributeError):
except AttributeError:
pass
if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)

View File

@ -1,14 +1,13 @@
import logging
from flask_restx import Resource, marshal_with, reqparse
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
@ -16,12 +15,35 @@ from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from .. import console_ns
from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class Parser(BaseModel):
node_id: str
class ParserEnable(BaseModel):
trigger_id: str
enable_trigger: bool
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@ -29,10 +51,9 @@ class WebhookTriggerApi(Resource):
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
node_id = str(args["node_id"])
node_id = args.node_id
with Session(db.engine) as session:
# Get webhook trigger for this app and node
@ -51,6 +72,7 @@ class WebhookTriggerApi(Resource):
return webhook_trigger
@console_ns.route("/apps/<uuid:app_id>/triggers")
class AppTriggersApi(Resource):
"""App Triggers list API"""
@ -90,7 +112,9 @@ class AppTriggersApi(Resource):
return {"data": triggers}
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@ -99,17 +123,11 @@ class AppTriggerEnableApi(Resource):
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
parser = (
reqparse.RequestParser()
.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = ParserEnable.model_validate(console_ns.payload)
assert current_user.current_tenant_id is not None
trigger_id = args["trigger_id"]
trigger_id = args.trigger_id
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
@ -124,7 +142,7 @@ class AppTriggerEnableApi(Resource):
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
@ -137,8 +155,3 @@ class AppTriggerEnableApi(Resource):
trigger.icon = "" # type: ignore
return trigger
console_ns.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
console_ns.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
console_ns.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

View File

@ -15,7 +15,6 @@ from controllers.console.app.error import (
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
@ -31,6 +30,7 @@ from libs.login import current_user
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
from .. import console_ns
@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
class CompletionApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != "completion":
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
parser = (
@ -102,12 +102,18 @@ class CompletionApi(InstalledAppResource):
class CompletionStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
if app_model.mode != "completion":
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200
@ -184,6 +190,12 @@ class ChatStopApi(InstalledAppResource):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user.id,
app_mode=app_mode,
)
return {"result": "success"}, 200

View File

@ -1,8 +1,10 @@
from datetime import datetime
from typing import Literal
import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -42,20 +44,198 @@ from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _init_parser():
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
return parser
class AccountInitPayload(BaseModel):
interface_language: str
timezone: str
invitation_code: str | None = None
@field_validator("interface_language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str) -> str:
return timezone(value)
class AccountNamePayload(BaseModel):
name: str = Field(min_length=3, max_length=30)
class AccountAvatarPayload(BaseModel):
avatar: str
class AccountInterfaceLanguagePayload(BaseModel):
interface_language: str
@field_validator("interface_language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
class AccountInterfaceThemePayload(BaseModel):
interface_theme: Literal["light", "dark"]
class AccountTimezonePayload(BaseModel):
timezone: str
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str) -> str:
return timezone(value)
class AccountPasswordPayload(BaseModel):
password: str | None = None
new_password: str
repeat_new_password: str
@model_validator(mode="after")
def check_passwords_match(self) -> "AccountPasswordPayload":
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self
class AccountDeletePayload(BaseModel):
token: str
code: str
class AccountDeletionFeedbackPayload(BaseModel):
email: str
feedback: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class EducationActivatePayload(BaseModel):
token: str
institution: str
role: str
class EducationAutocompleteQuery(BaseModel):
keywords: str
page: int = 0
limit: int = 20
class ChangeEmailSendPayload(BaseModel):
email: str
language: str | None = None
phase: str | None = None
token: str | None = None
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailValidityPayload(BaseModel):
email: str
code: str
token: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailResetPayload(BaseModel):
new_email: str
token: str
@field_validator("new_email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class CheckEmailUniquePayload(BaseModel):
email: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
console_ns.schema_model(
AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountInterfaceLanguagePayload.__name__,
AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountInterfaceThemePayload.__name__,
AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountTimezonePayload.__name__,
AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountPasswordPayload.__name__,
AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletePayload.__name__,
AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletionFeedbackPayload.__name__,
AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationActivatePayload.__name__,
EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationAutocompleteQuery.__name__,
EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailSendPayload.__name__,
ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailValidityPayload.__name__,
ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailResetPayload.__name__,
ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
CheckEmailUniquePayload.__name__,
CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@console_ns.expect(_init_parser())
@console_ns.expect(console_ns.models[AccountInitPayload.__name__])
@setup_required
@login_required
def post(self):
@ -64,17 +244,18 @@ class AccountInitApi(Resource):
if account.status == "active":
raise AccountAlreadyInitedError()
args = _init_parser().parse_args()
payload = console_ns.payload or {}
args = AccountInitPayload.model_validate(payload)
if dify_config.EDITION == "CLOUD":
if not args["invitation_code"]:
if not args.invitation_code:
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
.where(
InvitationCode.code == args["invitation_code"],
InvitationCode.code == args.invitation_code,
InvitationCode.status == "unused",
)
.first()
@ -88,8 +269,8 @@ class AccountInitApi(Resource):
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = "active"
account.initialized_at = naive_utc_now()
@ -110,137 +291,104 @@ class AccountProfileApi(Resource):
return current_user
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@console_ns.expect(parser_name)
@console_ns.expect(console_ns.models[AccountNamePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_name.parse_args()
# Validate account name length
if len(args["name"]) < 3 or len(args["name"]) > 30:
raise ValueError("Account name must be between 3 and 30 characters.")
updated_account = AccountService.update_account(current_user, name=args["name"])
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@console_ns.expect(parser_avatar)
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_avatar.parse_args()
payload = console_ns.payload or {}
args = AccountAvatarPayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account
parser_interface = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@console_ns.expect(parser_interface)
@console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_interface.parse_args()
payload = console_ns.payload or {}
args = AccountInterfaceLanguagePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account
parser_theme = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@console_ns.expect(parser_theme)
@console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_theme.parse_args()
payload = console_ns.payload or {}
args = AccountInterfaceThemePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@console_ns.expect(parser_timezone)
@console_ns.expect(console_ns.models[AccountTimezonePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_timezone.parse_args()
payload = console_ns.payload or {}
args = AccountTimezonePayload.model_validate(payload)
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args["timezone"] not in pytz.all_timezones:
raise ValueError("Invalid timezone string.")
updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account
parser_pw = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@console_ns.expect(parser_pw)
@console_ns.expect(console_ns.models[AccountPasswordPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_pw.parse_args()
if args["new_password"] != args["repeat_new_password"]:
raise RepeatPasswordNotMatchError()
payload = console_ns.payload or {}
args = AccountPasswordPayload.model_validate(payload)
try:
AccountService.update_account_password(current_user, args["password"], args["new_password"])
AccountService.update_account_password(current_user, args.password, args.new_password)
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
@ -316,25 +464,19 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
parser_delete = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@console_ns.expect(parser_delete)
@console_ns.expect(console_ns.models[AccountDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
args = parser_delete.parse_args()
payload = console_ns.payload or {}
args = AccountDeletePayload.model_validate(payload)
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
if not AccountService.verify_account_deletion_code(args.token, args.code):
raise InvalidAccountDeletionCodeError()
AccountService.delete_account(account)
@ -342,21 +484,15 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
parser_feedback = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@console_ns.expect(parser_feedback)
@console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__])
@setup_required
def post(self):
args = parser_feedback.parse_args()
payload = console_ns.payload or {}
args = AccountDeletionFeedbackPayload.model_validate(payload)
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
BillingService.update_account_deletion_feedback(args.email, args.feedback)
return {"result": "success"}
@ -379,14 +515,6 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
parser_edu = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
@ -396,7 +524,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean,
}
@console_ns.expect(parser_edu)
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -405,9 +533,10 @@ class EducationApi(Resource):
def post(self):
account, _ = current_account_with_tenant()
args = parser_edu.parse_args()
payload = console_ns.payload or {}
args = EducationActivatePayload.model_validate(payload)
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role)
@setup_required
@login_required
@ -425,14 +554,6 @@ class EducationApi(Resource):
return res
parser_autocomplete = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
@ -441,7 +562,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean,
}
@console_ns.expect(parser_autocomplete)
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -449,46 +570,39 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
args = parser_autocomplete.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
parser_change_email = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@console_ns.expect(parser_change_email)
@console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_change_email.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailSendPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = None
user_email = args["email"]
if args["phase"] is not None and args["phase"] == "new_email":
if args["token"] is None:
user_email = args.email
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args["token"])
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
@ -497,118 +611,103 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidEmailError()
else:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
if account is None:
raise AccountNotFound()
token = AccountService.send_change_email_email(
account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
)
return {"result": "success", "data": token}
parser_validity = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@console_ns.expect(parser_validity)
@console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
args = parser_validity.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args["email"]
user_email = args.email
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
token_data = AccountService.get_change_email_data(args["token"])
token_data = AccountService.get_change_email_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args["token"])
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(args["email"])
AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_reset = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@console_ns.expect(parser_reset)
@console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
args = parser_reset.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
if AccountService.is_account_in_freeze(args["new_email"]):
if AccountService.is_account_in_freeze(args.new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["new_email"]):
if not AccountService.check_email_unique(args.new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args["token"])
reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args["token"])
AccountService.revoke_change_email_token(args.token)
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=args["new_email"])
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
AccountService.send_change_email_completed_notify_email(
email=args["new_email"],
email=args.new_email,
)
return updated_account
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@console_ns.expect(parser_check)
@console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__])
@setup_required
def post(self):
args = parser_check.parse_args()
if AccountService.is_account_in_freeze(args["email"]):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
if AccountService.is_account_in_freeze(args.email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["email"]):
if not AccountService.check_email_unique(args.email):
raise EmailAlreadyInUseError()
return {"result": "success"}

View File

@ -1,7 +1,8 @@
from urllib import parse
from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
import services
from configs import dify_config
@ -31,6 +32,53 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class MemberInvitePayload(BaseModel):
emails: list[str] = Field(default_factory=list)
role: TenantAccountRole
language: str | None = None
class MemberRoleUpdatePayload(BaseModel):
role: str
class OwnerTransferEmailPayload(BaseModel):
language: str | None = None
class OwnerTransferCheckPayload(BaseModel):
code: str
token: str
class OwnerTransferPayload(BaseModel):
token: str
console_ns.schema_model(
MemberInvitePayload.__name__,
MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
MemberRoleUpdatePayload.__name__,
MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferEmailPayload.__name__,
OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferCheckPayload.__name__,
OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferPayload.__name__,
OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
@ -48,29 +96,22 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_invite = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@console_ns.expect(parser_invite)
@console_ns.expect(console_ns.models[MemberInvitePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
args = parser_invite.parse_args()
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
invitee_emails = args["emails"]
invitee_role = args["role"]
interface_language = args["language"]
invitee_emails = args.emails
invitee_role = args.role
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
@ -146,20 +187,18 @@ class MemberCancelInviteApi(Resource):
}, 200
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@console_ns.expect(parser_update)
@console_ns.expect(console_ns.models[MemberRoleUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
args = parser_update.parse_args()
new_role = args["role"]
payload = console_ns.payload or {}
args = MemberRoleUpdatePayload.model_validate(payload)
new_role = args.role
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
@ -197,20 +236,18 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
@console_ns.expect(parser_send)
@console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
args = parser_send.parse_args()
payload = console_ns.payload or {}
args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@ -221,7 +258,7 @@ class SendOwnerTransferEmailApi(Resource):
if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
@ -238,22 +275,16 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
parser_owner = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
@console_ns.expect(parser_owner)
@console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
args = parser_owner.parse_args()
payload = console_ns.payload or {}
args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@ -267,41 +298,37 @@ class OwnerTransferCheckApi(Resource):
if is_owner_transfer_error_rate_limit:
raise OwnerTransferLimitError()
token_data = AccountService.get_owner_transfer_data(args["token"])
token_data = AccountService.get_owner_transfer_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
if args.code != token_data.get("code"):
AccountService.add_owner_transfer_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_owner_transfer_token(args["token"])
AccountService.revoke_owner_transfer_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
_, new_token = AccountService.generate_owner_transfer_token(user_email, code=args.code, additional_data={})
AccountService.reset_owner_transfer_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_owner_transfer = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
@console_ns.expect(parser_owner_transfer)
@console_ns.expect(console_ns.models[OwnerTransferPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id):
args = parser_owner_transfer.parse_args()
payload = console_ns.payload or {}
args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
@ -313,14 +340,14 @@ class OwnerTransfer(Resource):
if current_user.id == str(member_id):
raise CannotTransferOwnerToSelfError()
transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
transfer_token_data = AccountService.get_owner_transfer_data(args.token)
if not transfer_token_data:
raise InvalidTokenError()
if transfer_token_data.get("email") != current_user.email:
raise InvalidEmailError()
AccountService.revoke_owner_transfer_token(args["token"])
AccountService.revoke_owner_transfer_token(args.token)
member = db.session.get(Account, str(member_id))
if not member:

View File

@ -1,31 +1,123 @@
import io
from typing import Any, Literal
from flask import send_file
from flask_restx import Resource, reqparse
from flask import request, send_file
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
parser_model = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ParserModelList(BaseModel):
model_type: ModelType | None = None
class ParserCredentialId(BaseModel):
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_optional_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserCredentialCreate(BaseModel):
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
class ParserCredentialUpdate(BaseModel):
credential_id: str
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
@field_validator("credential_id")
@classmethod
def validate_update_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialDelete(BaseModel):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_delete_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialSwitch(BaseModel):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_switch_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialValidate(BaseModel):
credentials: dict[str, Any]
class ParserPreferredProviderType(BaseModel):
preferred_provider_type: Literal["system", "custom"]
console_ns.schema_model(
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserCredentialId.__name__,
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialCreate.__name__,
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialUpdate.__name__,
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialDelete.__name__,
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialSwitch.__name__,
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialValidate.__name__,
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferredProviderType.__name__,
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
@console_ns.expect(parser_model)
@console_ns.expect(console_ns.models[ParserModelList.__name__])
@setup_required
@login_required
@account_initialization_required
@ -33,38 +125,18 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
args = parser_model.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = ParserModelList.model_validate(payload)
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.model_type)
return jsonable_encoder({"data": provider_list})
parser_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
@console_ns.expect(parser_cred)
@console_ns.expect(console_ns.models[ParserCredentialId.__name__])
@setup_required
@login_required
@account_initialization_required
@ -72,23 +144,25 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
args = parser_cred.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = ParserCredentialId.model_validate(payload)
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential(
tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
tenant_id=tenant_id, provider=provider, credential_id=args.credential_id
)
return {"credentials": credentials}
@console_ns.expect(parser_post_cred)
@console_ns.expect(console_ns.models[ParserCredentialCreate.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_post_cred.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialCreate.model_validate(payload)
model_provider_service = ModelProviderService()
@ -96,15 +170,15 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.create_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_name=args["name"],
credentials=args.credentials,
credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
@console_ns.expect(parser_put_cred)
@console_ns.expect(console_ns.models[ParserCredentialUpdate.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -112,7 +186,8 @@ class ModelProviderCredentialApi(Resource):
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_put_cred.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialUpdate.model_validate(payload)
model_provider_service = ModelProviderService()
@ -120,71 +195,64 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.update_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
credential_name=args["name"],
credentials=args.credentials,
credential_id=args.credential_id,
credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
@console_ns.expect(parser_delete_cred)
@console_ns.expect(console_ns.models[ParserCredentialDelete.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_delete_cred.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialDelete.model_validate(payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
)
return {"result": "success"}, 204
parser_switch = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
@console_ns.expect(parser_switch)
@console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_switch.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialSwitch.model_validate(payload)
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_tenant_id,
provider=provider,
credential_id=args["credential_id"],
credential_id=args.credential_id,
)
return {"result": "success"}
parser_validate = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
@console_ns.expect(parser_validate)
@console_ns.expect(console_ns.models[ParserCredentialValidate.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_validate.parse_args()
payload = console_ns.payload or {}
args = ParserCredentialValidate.model_validate(payload)
tenant_id = current_tenant_id
@ -195,7 +263,7 @@ class ModelProviderValidateApi(Resource):
try:
model_provider_service.validate_provider_credentials(
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
tenant_id=tenant_id, provider=provider, credentials=args.credentials
)
except CredentialsValidateFailedError as ex:
result = False
@ -228,19 +296,9 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
parser_preferred = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
@console_ns.expect(parser_preferred)
@console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -250,11 +308,12 @@ class PreferredProviderTypeUpdateApi(Resource):
tenant_id = current_tenant_id
args = parser_preferred.parse_args()
payload = console_ns.payload or {}
args = ParserPreferredProviderType.model_validate(payload)
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(
tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
tenant_id=tenant_id, provider=provider, preferred_provider_type=args.preferred_provider_type
)
return {"result": "success"}

View File

@ -1,52 +1,172 @@
import logging
from typing import Any
from flask_restx import Resource, reqparse
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
parser_get_default = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
class ParserGetDefault(BaseModel):
model_type: ModelType
class ParserPostDefault(BaseModel):
class Inner(BaseModel):
model_type: ModelType
model: str
provider: str | None = None
model_settings: list[Inner]
console_ns.schema_model(
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
parser_post_default = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
console_ns.schema_model(
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserDeleteModels(BaseModel):
model: str
model_type: ModelType
console_ns.schema_model(
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class LoadBalancingPayload(BaseModel):
configs: list[dict[str, Any]] | None = None
enabled: bool | None = None
class ParserPostModels(BaseModel):
model: str
model_type: ModelType
load_balancing: LoadBalancingPayload | None = None
config_from: str | None = None
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserGetCredentials(BaseModel):
model: str
model_type: ModelType
config_from: str | None = None
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_get_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserCredentialBase(BaseModel):
model: str
model_type: ModelType
class ParserCreateCredential(ParserCredentialBase):
name: str | None = Field(default=None, max_length=30)
credentials: dict[str, Any]
class ParserUpdateCredential(ParserCredentialBase):
credential_id: str
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
@field_validator("credential_id")
@classmethod
def validate_update_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserDeleteCredential(ParserCredentialBase):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_delete_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserParameter(BaseModel):
model: str
console_ns.schema_model(
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserGetCredentials.__name__,
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCreateCredential.__name__,
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserUpdateCredential.__name__,
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDeleteCredential.__name__,
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
@console_ns.expect(parser_get_default)
@console_ns.expect(console_ns.models[ParserGetDefault.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_get_default.parse_args()
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"]
tenant_id=tenant_id, model_type=args.model_type
)
return jsonable_encoder({"data": default_model_entity})
@console_ns.expect(parser_post_default)
@console_ns.expect(console_ns.models[ParserPostDefault.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -54,66 +174,31 @@ class DefaultModelApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_post_default.parse_args()
args = ParserPostDefault.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
model_settings = args.model_settings
for model_setting in model_settings:
if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
raise ValueError("invalid model type")
if "provider" not in model_setting:
if model_setting.provider is None:
continue
if "model" not in model_setting:
raise ValueError("invalid model")
try:
model_provider_service.update_default_model_of_model_type(
tenant_id=tenant_id,
model_type=model_setting["model_type"],
provider=model_setting["provider"],
model=model_setting["model"],
model_type=model_setting.model_type,
provider=model_setting.provider,
model=model_setting.model,
)
except Exception as ex:
logger.exception(
"Failed to update default model, model type: %s, model: %s",
model_setting["model_type"],
model_setting.get("model"),
model_setting.model_type,
model_setting.model,
)
raise ex
return {"result": "success"}
parser_post_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
parser_delete_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@ -127,7 +212,7 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models})
@console_ns.expect(parser_post_models)
@console_ns.expect(console_ns.models[ParserPostModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -135,45 +220,45 @@ class ModelProviderModelApi(Resource):
def post(self, provider: str):
# To save the model's load balance configs
_, tenant_id = current_account_with_tenant()
args = parser_post_models.parse_args()
args = ParserPostModels.model_validate(console_ns.payload)
if args.get("config_from", "") == "custom-model":
if not args.get("credential_id"):
if args.config_from == "custom-model":
if not args.credential_id:
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credential_id=args["credential_id"],
model_type=args.model_type,
model=args.model,
credential_id=args.credential_id,
)
model_load_balancing_service = ModelLoadBalancingService()
if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
if args.load_balancing and args.load_balancing.configs:
# save load balancing configs
model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
configs=args["load_balancing"]["configs"],
config_from=args.get("config_from", ""),
model=args.model,
model_type=args.model_type,
configs=args.load_balancing.configs,
config_from=args.config_from or "",
)
if args.get("load_balancing", {}).get("enabled"):
if args.load_balancing.enabled:
model_load_balancing_service.enable_model_load_balancing(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
else:
model_load_balancing_service.disable_model_load_balancing(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 200
@console_ns.expect(parser_delete_models)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
@setup_required
@login_required
@is_admin_or_owner_required
@ -181,113 +266,53 @@ class ModelProviderModelApi(Resource):
def delete(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_delete_models.parse_args()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_model(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 204
parser_get_credentials = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
@console_ns.expect(parser_get_credentials)
@console_ns.expect(console_ns.models[ParserGetCredentials.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_get_credentials.parse_args()
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore
model_provider_service = ModelProviderService()
current_credential = model_provider_service.get_model_credential(
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credential_id=args.get("credential_id"),
model_type=args.model_type,
model=args.model,
credential_id=args.credential_id,
)
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
model=args.model,
model_type=args.model_type,
config_from=args.config_from or "",
)
if args.get("config_from", "") == "predefined-model":
if args.config_from == "predefined-model":
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider
)
else:
model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
model_type = args.model_type
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
)
return jsonable_encoder(
@ -304,7 +329,7 @@ class ModelProviderModelCredentialApi(Resource):
}
)
@console_ns.expect(parser_post_cred)
@console_ns.expect(console_ns.models[ParserCreateCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -312,7 +337,7 @@ class ModelProviderModelCredentialApi(Resource):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_post_cred.parse_args()
args = ParserCreateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -320,30 +345,30 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.create_model_credential(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
credential_name=args["name"],
model=args.model,
model_type=args.model_type,
credentials=args.credentials,
credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
logger.exception(
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
tenant_id,
args.get("model"),
args.get("model_type"),
args.model,
args.model_type,
)
raise ValueError(str(ex))
return {"result": "success"}, 201
@console_ns.expect(parser_put_cred)
@console_ns.expect(console_ns.models[ParserUpdateCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_put_cred.parse_args()
args = ParserUpdateCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -351,106 +376,87 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.update_model_credential(
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credentials=args["credentials"],
credential_id=args["credential_id"],
credential_name=args["name"],
model_type=args.model_type,
model=args.model,
credentials=args.credentials,
credential_id=args.credential_id,
credential_name=args.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
@console_ns.expect(parser_delete_cred)
@console_ns.expect(console_ns.models[ParserDeleteCredential.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_delete_cred.parse_args()
args = ParserDeleteCredential.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credential_id=args["credential_id"],
model_type=args.model_type,
model=args.model,
credential_id=args.credential_id,
)
return {"result": "success"}, 204
parser_switch = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
class ParserSwitch(BaseModel):
model: str
model_type: ModelType
credential_id: str
console_ns.schema_model(
ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
@console_ns.expect(parser_switch)
@console_ns.expect(console_ns.models[ParserSwitch.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_switch.parse_args()
args = ParserSwitch.model_validate(console_ns.payload)
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
credential_id=args["credential_id"],
model_type=args.model_type,
model=args.model,
credential_id=args.credential_id,
)
return {"result": "success"}
parser_model_enable_disable = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
@console_ns.expect(parser_model_enable_disable)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_model_enable_disable.parse_args()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}
@ -460,48 +466,43 @@ class ModelProviderModelEnableApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
@console_ns.expect(parser_model_enable_disable)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_model_enable_disable.parse_args()
args = ParserDeleteModels.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}
parser_validate = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
class ParserValidate(BaseModel):
model: str
model_type: ModelType
credentials: dict
console_ns.schema_model(
ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
@console_ns.expect(parser_validate)
@console_ns.expect(console_ns.models[ParserValidate.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
args = parser_validate.parse_args()
args = ParserValidate.model_validate(console_ns.payload)
model_provider_service = ModelProviderService()
@ -512,9 +513,9 @@ class ModelProviderModelValidateApi(Resource):
model_provider_service.validate_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
model=args.model,
model_type=args.model_type,
credentials=args.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@ -528,24 +529,19 @@ class ModelProviderModelValidateApi(Resource):
return response
parser_parameter = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
@console_ns.expect(parser_parameter)
@console_ns.expect(console_ns.models[ParserParameter.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
args = parser_parameter.parse_args()
args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
tenant_id=tenant_id, provider=provider, model=args["model"]
tenant_id=tenant_id, provider=provider, model=args.model
)
return jsonable_encoder({"data": parameter_rules})

View File

@ -1,7 +1,9 @@
import io
from typing import Literal
from flask import request, send_file
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -17,6 +19,8 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@ -37,88 +41,251 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
parser_list = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
class ParserList(BaseModel):
page: int = Field(default=1)
page_size: int = Field(default=256)
console_ns.schema_model(
ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@console_ns.expect(parser_list)
@console_ns.expect(console_ns.models[ParserList.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_list.parse_args()
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
class ParserLatest(BaseModel):
plugin_ids: list[str]
console_ns.schema_model(
ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserIcon(BaseModel):
tenant_id: str
filename: str
class ParserAsset(BaseModel):
plugin_unique_identifier: str
file_name: str
class ParserGithubUpload(BaseModel):
repo: str
version: str
package: str
class ParserPluginIdentifiers(BaseModel):
plugin_unique_identifiers: list[str]
class ParserGithubInstall(BaseModel):
plugin_unique_identifier: str
repo: str
version: str
package: str
class ParserPluginIdentifierQuery(BaseModel):
plugin_unique_identifier: str
class ParserTasks(BaseModel):
page: int
page_size: int
class ParserMarketplaceUpgrade(BaseModel):
original_plugin_unique_identifier: str
new_plugin_unique_identifier: str
class ParserGithubUpgrade(BaseModel):
original_plugin_unique_identifier: str
new_plugin_unique_identifier: str
repo: str
version: str
package: str
class ParserUninstall(BaseModel):
plugin_installation_id: str
class ParserPermissionChange(BaseModel):
install_permission: TenantPluginPermission.InstallPermission
debug_permission: TenantPluginPermission.DebugPermission
class ParserDynamicOptions(BaseModel):
plugin_id: str
provider: str
action: str
parameter: str
credential_id: str | None = None
provider_type: Literal["tool", "trigger"]
class PluginPermissionSettingsPayload(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
class PluginAutoUpgradeSettingsPayload(BaseModel):
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = (
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
)
upgrade_time_of_day: int = 0
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
exclude_plugins: list[str] = Field(default_factory=list)
include_plugins: list[str] = Field(default_factory=list)
class ParserPreferencesChange(BaseModel):
permission: PluginPermissionSettingsPayload
auto_upgrade: PluginAutoUpgradeSettingsPayload
class ParserExcludePlugin(BaseModel):
plugin_id: str
class ParserReadme(BaseModel):
plugin_unique_identifier: str
language: str = Field(default="en-US")
console_ns.schema_model(
ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifiers.__name__,
ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifierQuery.__name__,
ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserMarketplaceUpgrade.__name__,
ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPermissionChange.__name__,
ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDynamicOptions.__name__,
ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferencesChange.__name__,
ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserExcludePlugin.__name__,
ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
@console_ns.expect(parser_latest)
@console_ns.expect(console_ns.models[ParserLatest.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
args = parser_latest.parse_args()
args = ParserLatest.model_validate(console_ns.payload)
try:
versions = PluginService.list_latest_versions(args["plugin_ids"])
versions = PluginService.list_latest_versions(args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"versions": versions})
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
@console_ns.expect(parser_ids)
@console_ns.expect(console_ns.models[ParserLatest.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_ids.parse_args()
args = ParserLatest.model_validate(console_ns.payload)
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins})
parser_icon = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
@console_ns.expect(parser_icon)
@console_ns.expect(console_ns.models[ParserIcon.__name__])
@setup_required
def get(self):
args = parser_icon.parse_args()
args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@ -128,20 +295,16 @@ class PluginIconApi(Resource):
@console_ns.route("/workspaces/current/plugin/asset")
class PluginAssetApi(Resource):
@console_ns.expect(console_ns.models[ParserAsset.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
req = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("file_name", type=str, required=True, location="args")
)
args = req.parse_args()
args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore
_, tenant_id = current_account_with_tenant()
try:
binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
except PluginDaemonClientSideError as e:
raise ValueError(e)
@ -171,17 +334,9 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
parser_github = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
@console_ns.expect(parser_github)
@console_ns.expect(console_ns.models[ParserGithubUpload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -189,10 +344,10 @@ class PluginUploadFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_github.parse_args()
args = ParserGithubUpload.model_validate(console_ns.payload)
try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@ -223,47 +378,28 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
parser_pkg = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
@console_ns.expect(parser_pkg)
@console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_pkg.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
parser_githubapi = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
@console_ns.expect(parser_githubapi)
@console_ns.expect(console_ns.models[ParserGithubInstall.__name__])
@setup_required
@login_required
@account_initialization_required
@ -271,15 +407,15 @@ class PluginInstallFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_githubapi.parse_args()
args = ParserGithubInstall.model_validate(console_ns.payload)
try:
response = PluginService.install_from_github(
tenant_id,
args["plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
args.plugin_unique_identifier,
args.repo,
args.version,
args.package,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@ -287,14 +423,9 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
parser_marketplace = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
@console_ns.expect(parser_marketplace)
@console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
@setup_required
@login_required
@account_initialization_required
@ -302,43 +433,33 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_marketplace.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder(response)
parser_pkgapi = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
@console_ns.expect(parser_pkgapi)
@console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_pkgapi.parse_args()
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_marketplace_pkg(
tenant_id,
args["plugin_unique_identifier"],
args.plugin_unique_identifier,
)
}
)
@ -346,14 +467,9 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
parser_fetch = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
@console_ns.expect(parser_fetch)
@console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -361,30 +477,19 @@ class PluginFetchManifestApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_fetch.parse_args()
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_plugin_manifest(
tenant_id, args["plugin_unique_identifier"]
).model_dump()
}
{"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
parser_tasks = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
@console_ns.expect(parser_tasks)
@console_ns.expect(console_ns.models[ParserTasks.__name__])
@setup_required
@login_required
@account_initialization_required
@ -392,12 +497,10 @@ class PluginFetchInstallTasksApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
args = parser_tasks.parse_args()
args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
return jsonable_encoder(
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
)
return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
@ -462,16 +565,9 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
parser_marketplace_api = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
@console_ns.expect(parser_marketplace_api)
@console_ns.expect(console_ns.models[ParserMarketplaceUpgrade.__name__])
@setup_required
@login_required
@account_initialization_required
@ -479,31 +575,21 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_marketplace_api.parse_args()
args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace(
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
tenant_id, args.original_plugin_unique_identifier, args.new_plugin_unique_identifier
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
parser_github_post = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
@console_ns.expect(parser_github_post)
@console_ns.expect(console_ns.models[ParserGithubUpgrade.__name__])
@setup_required
@login_required
@account_initialization_required
@ -511,56 +597,44 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
args = parser_github_post.parse_args()
args = ParserGithubUpgrade.model_validate(console_ns.payload)
try:
return jsonable_encoder(
PluginService.upgrade_plugin_with_github(
tenant_id,
args["original_plugin_unique_identifier"],
args["new_plugin_unique_identifier"],
args["repo"],
args["version"],
args["package"],
args.original_plugin_unique_identifier,
args.new_plugin_unique_identifier,
args.repo,
args.version,
args.package,
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
parser_uninstall = reqparse.RequestParser().add_argument(
"plugin_installation_id", type=str, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
@console_ns.expect(parser_uninstall)
@console_ns.expect(console_ns.models[ParserUninstall.__name__])
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
args = parser_uninstall.parse_args()
args = ParserUninstall.model_validate(console_ns.payload)
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
parser_change_post = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
@console_ns.expect(parser_change_post)
@console_ns.expect(console_ns.models[ParserPermissionChange.__name__])
@setup_required
@login_required
@account_initialization_required
@ -570,14 +644,15 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
args = parser_change_post.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
args = ParserPermissionChange.model_validate(console_ns.payload)
tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
return {
"success": PluginPermissionService.change_permission(
tenant_id, args.install_permission, args.debug_permission
)
}
@console_ns.route("/workspaces/current/plugin/permission/fetch")
@ -605,20 +680,9 @@ class PluginFetchPermissionApi(Resource):
)
parser_dynamic = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("credential_id", type=str, required=False, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
@console_ns.expect(parser_dynamic)
@console_ns.expect(console_ns.models[ParserDynamicOptions.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -627,18 +691,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
args = parser_dynamic.parse_args()
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args["plugin_id"],
provider=args["provider"],
action=args["action"],
parameter=args["parameter"],
credential_id=args["credential_id"],
provider_type=args["provider_type"],
plugin_id=args.plugin_id,
provider=args.provider,
action=args.action,
parameter=args.parameter,
credential_id=args.credential_id,
provider_type=args.provider_type,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@ -646,16 +710,9 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
parser_change = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@console_ns.expect(parser_change)
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
@setup_required
@login_required
@account_initialization_required
@ -664,22 +721,20 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
args = parser_change.parse_args()
args = ParserPreferencesChange.model_validate(console_ns.payload)
permission = args["permission"]
permission = args.permission
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
install_permission = permission.install_permission
debug_permission = permission.debug_permission
auto_upgrade = args["auto_upgrade"]
auto_upgrade = args.auto_upgrade
strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
auto_upgrade.get("strategy_setting", "fix_only")
)
upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
exclude_plugins = auto_upgrade.get("exclude_plugins", [])
include_plugins = auto_upgrade.get("include_plugins", [])
strategy_setting = auto_upgrade.strategy_setting
upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
upgrade_mode = auto_upgrade.upgrade_mode
exclude_plugins = auto_upgrade.exclude_plugins
include_plugins = auto_upgrade.include_plugins
# set permission
set_permission_result = PluginPermissionService.change_permission(
@ -744,12 +799,9 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@console_ns.expect(parser_exclude)
@console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
@setup_required
@login_required
@account_initialization_required
@ -757,28 +809,20 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin
_, tenant_id = current_account_with_tenant()
args = parser_exclude.parse_args()
args = ParserExcludePlugin.model_validate(console_ns.payload)
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
@console_ns.route("/workspaces/current/plugin/readme")
class PluginReadmeApi(Resource):
@console_ns.expect(console_ns.models[ParserReadme.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("language", type=str, required=False, location="args")
)
args = parser.parse_args()
args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore
return jsonable_encoder(
{
"readme": PluginService.fetch_plugin_readme(
tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
)
}
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
)

View File

@ -6,8 +6,6 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -23,9 +21,13 @@ from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
logger = logging.getLogger(__name__)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@ -38,6 +40,7 @@ class TriggerProviderIconApi(Resource):
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
@console_ns.route("/workspaces/current/triggers")
class TriggerProviderListApi(Resource):
@setup_required
@login_required
@ -50,6 +53,7 @@ class TriggerProviderListApi(Resource):
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@ -64,6 +68,7 @@ class TriggerProviderInfoApi(Resource):
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@ -87,7 +92,16 @@ class TriggerSubscriptionListApi(Resource):
raise
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(parser)
@setup_required
@login_required
@is_admin_or_owner_required
@ -97,9 +111,6 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
args = parser.parse_args()
try:
@ -116,6 +127,9 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
@ -127,7 +141,18 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
parser_api = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@is_admin_or_owner_required
@ -136,12 +161,8 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
"""Verify a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
parser = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_api.parse_args()
try:
# Use atomic update_and_verify to prevent race conditions
@ -159,7 +180,24 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
raise ValueError(str(e)) from e
parser_update_api = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@account_initialization_required
@ -169,18 +207,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
parser = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_update_api.parse_args()
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@ -200,6 +227,9 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
@ -218,7 +248,11 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@is_admin_or_owner_required
@ -227,18 +261,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
parser = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_update_api.parse_args()
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
@ -258,6 +281,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
raise ValueError(str(e)) from e
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
class TriggerSubscriptionDeleteApi(Resource):
@setup_required
@login_required
@ -291,6 +317,7 @@ class TriggerSubscriptionDeleteApi(Resource):
raise
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize")
class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@ -374,6 +401,7 @@ class TriggerOAuthAuthorizeApi(Resource):
raise
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
@ -438,6 +466,14 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_oauth_client = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@ -484,6 +520,7 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e)
raise
@console_ns.expect(parser_oauth_client)
@setup_required
@login_required
@is_admin_or_owner_required
@ -493,12 +530,7 @@ class TriggerOAuthClientManageApi(Resource):
user = current_user
assert user.current_tenant_id is not None
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_oauth_client.parse_args()
try:
provider_id = TriggerProviderID(provider)
@ -536,52 +568,3 @@ class TriggerOAuthClientManageApi(Resource):
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
# Trigger Subscription
console_ns.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
console_ns.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
console_ns.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
console_ns.add_resource(
TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list"
)
console_ns.add_resource(
TriggerSubscriptionDeleteApi,
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
# Trigger Subscription Builder
console_ns.add_resource(
TriggerSubscriptionBuilderCreateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
console_ns.add_resource(
TriggerSubscriptionBuilderGetApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
console_ns.add_resource(
TriggerSubscriptionBuilderUpdateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
console_ns.add_resource(
TriggerSubscriptionBuilderVerifyApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
console_ns.add_resource(
TriggerSubscriptionBuilderBuildApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
console_ns.add_resource(
TriggerSubscriptionBuilderLogsApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
# OAuth
console_ns.add_resource(
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
)
console_ns.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
console_ns.add_resource(
TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
)

View File

@ -1,7 +1,8 @@
import logging
from flask import request
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@ -32,6 +33,45 @@ from services.file_service import FileService
from services.workspace_service import WorkspaceService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkspaceListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
class SwitchWorkspacePayload(BaseModel):
tenant_id: str
class WorkspaceCustomConfigPayload(BaseModel):
remove_webapp_brand: bool | None = None
replace_webapp_logo: str | None = None
class WorkspaceInfoPayload(BaseModel):
name: str
console_ns.schema_model(
WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
SwitchWorkspacePayload.__name__,
SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceCustomConfigPayload.__name__,
WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceInfoPayload.__name__,
WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
provider_fields = {
@ -95,18 +135,15 @@ class TenantListApi(Resource):
@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource):
@console_ns.expect(console_ns.models[WorkspaceListQuery.__name__])
@setup_required
@admin_required
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = WorkspaceListQuery.model_validate(payload)
stmt = select(Tenant).order_by(Tenant.created_at.desc())
tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
tenants = db.paginate(select=stmt, page=args.page, per_page=args.limit, error_out=False)
has_more = False
if tenants.has_next:
@ -115,8 +152,8 @@ class WorkspaceListApi(Resource):
return {
"data": marshal(tenants.items, workspace_fields),
"has_more": has_more,
"limit": args["limit"],
"page": args["page"],
"limit": args.limit,
"page": args.page,
"total": tenants.total,
}, 200
@ -150,26 +187,24 @@ class TenantApi(Resource):
return WorkspaceService.get_tenant_info(tenant), 200
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
@console_ns.expect(parser_switch)
@console_ns.expect(console_ns.models[SwitchWorkspacePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_switch.parse_args()
payload = console_ns.payload or {}
args = SwitchWorkspacePayload.model_validate(payload)
# check if tenant_id is valid, 403 if not
try:
TenantService.switch_tenant(current_user, args["tenant_id"])
TenantService.switch_tenant(current_user, args.tenant_id)
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
if new_tenant is None:
raise ValueError("Tenant not found")
@ -178,24 +213,21 @@ class SwitchWorkspaceApi(Resource):
@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource):
@console_ns.expect(console_ns.models[WorkspaceCustomConfigPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("remove_webapp_brand", type=bool, location="json")
.add_argument("replace_webapp_logo", type=str, location="json")
)
args = parser.parse_args()
payload = console_ns.payload or {}
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
"replace_webapp_logo": args["replace_webapp_logo"]
if args["replace_webapp_logo"] is not None
"remove_webapp_brand": args.remove_webapp_brand,
"replace_webapp_logo": args.replace_webapp_logo
if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"),
}
@ -245,24 +277,22 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
@console_ns.expect(parser_info)
@console_ns.expect(console_ns.models[WorkspaceInfoPayload.__name__])
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
_, current_tenant_id = current_account_with_tenant()
args = parser_info.parse_args()
payload = console_ns.payload or {}
args = WorkspaceInfoPayload.model_validate(payload)
if not current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_tenant_id)
tenant.name = args["name"]
tenant.name = args.name
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}

View File

@ -17,7 +17,6 @@ from controllers.service_api.app.error import (
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
@ -30,6 +29,7 @@ from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@ -88,7 +88,7 @@ class CompletionApi(Resource):
This endpoint generates a completion based on the provided inputs and query.
Supports both blocking and streaming response modes.
"""
if app_model.mode != "completion":
if app_model.mode != AppMode.COMPLETION:
raise AppUnavailableError()
args = completion_parser.parse_args()
@ -147,10 +147,15 @@ class CompletionStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id: str):
"""Stop a running completion task."""
if app_model.mode != "completion":
if app_model.mode != AppMode.COMPLETION:
raise AppUnavailableError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.SERVICE_API,
user_id=end_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200
@ -244,6 +249,11 @@ class ChatStopApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.SERVICE_API,
user_id=end_user.id,
app_mode=app_mode,
)
return {"result": "success"}, 200

View File

@ -1,7 +1,7 @@
import logging
import time
from flask import jsonify
from flask import jsonify, request
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
@ -28,8 +28,14 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
return webhook_trigger, workflow, node_config, webhook_data, None
except ValueError as e:
# Fall back to raw extraction for error reporting
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
# Provide minimal context for error reporting without risking another parse failure
webhook_data = {
"method": request.method,
"headers": dict(request.headers),
"query_params": dict(request.args),
"body": {},
"files": {},
}
return webhook_trigger, workflow, node_config, webhook_data, str(e)

View File

@ -17,7 +17,6 @@ from controllers.web.error import (
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
@ -29,6 +28,7 @@ from libs import helper
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
@ -64,7 +64,7 @@ class CompletionApi(WebApiResource):
}
)
def post(self, app_model, end_user):
if app_model.mode != "completion":
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
parser = (
@ -125,10 +125,15 @@ class CompletionStopApi(WebApiResource):
}
)
def post(self, app_model, end_user, task_id):
if app_model.mode != "completion":
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.WEB_APP,
user_id=end_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200
@ -234,6 +239,11 @@ class ChatStopApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.WEB_APP,
user_id=end_user.id,
app_mode=app_mode,
)
return {"result": "success"}, 200

View File

@ -62,7 +62,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@ -72,7 +73,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
@ -580,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield workflow_finish_resp
elif event.stopped_by in (
@ -590,7 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session)
self._save_message(session=session, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@ -599,6 +600,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@ -616,7 +618,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@ -770,7 +772,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
def _save_message(
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@ -809,6 +817,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
# Extract model provider and model_id from workflow node executions for tracing
if message.workflow_run_id:
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
if model_info:
message.model_provider = model_info.get("provider")
message.model_id = model_info.get("model")
message_files = [
MessageFile(
message_id=message.id,
@ -826,6 +842,68 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
# Trigger MESSAGE_TRACE for tracing integrations
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
"""
Extract model provider and model_id from workflow node executions.
Returns dict with 'provider' and 'model' keys, or None if not found.
"""
try:
# Query workflow node executions for LLM or Agent nodes
stmt = (
select(WorkflowNodeExecutionModel)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.limit(1)
)
node_execution = session.scalar(stmt)
if not node_execution:
return None
# Try to extract from execution_metadata for agent nodes
if node_execution.execution_metadata:
try:
metadata = json.loads(node_execution.execution_metadata)
agent_log = metadata.get("agent_log", [])
# Look for the first agent thought with provider info
for log_entry in agent_log:
entry_metadata = log_entry.get("metadata", {})
provider_str = entry_metadata.get("provider")
if provider_str:
# Parse format like "langgenius/deepseek/deepseek"
parts = provider_str.split("/")
if len(parts) >= 3:
return {"provider": parts[1], "model": parts[2]}
elif len(parts) == 2:
return {"provider": parts[0], "model": parts[1]}
except (json.JSONDecodeError, KeyError, AttributeError) as e:
logger.debug("Failed to parse execution_metadata: %s", e)
# Try to extract from process_data for llm nodes
if node_execution.process_data:
try:
process_data = json.loads(node_execution.process_data)
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider and model:
return {"provider": provider, "model": model}
except (json.JSONDecodeError, KeyError) as e:
logger.debug("Failed to parse process_data: %s", e)
return None
except Exception as e:
logger.warning("Failed to extract model info from workflow: %s", e)
return None
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@ -155,8 +155,17 @@ class BaseAppGenerator:
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)
case VariableEntityType.CHECKBOX:
if not isinstance(value, bool):
raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
if isinstance(value, str):
normalized_value = value.strip().lower()
if normalized_value in {"true", "1", "yes", "on"}:
value = True
elif normalized_value in {"false", "0", "no", "off"}:
value = False
elif isinstance(value, (int, float)):
if value == 1:
value = True
elif value == 0:
value = False
case _:
raise AssertionError("this statement should be unreachable.")

View File

@ -40,6 +40,9 @@ class EasyUITaskState(TaskState):
"""
llm_result: LLMResult
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False
class WorkflowTaskState(TaskState):

View File

@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=state.dumps(),
pause_reasons=event.reasons,
)
def on_graph_end(self, error: Exception | None) -> None:

View File

@ -332,6 +332,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# Track streaming response times
if self._task_state.first_token_time is None:
self._task_state.first_token_time = time.perf_counter()
self._task_state.is_streaming_response = True
self._task_state.last_token_time = time.perf_counter()
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
@ -398,6 +404,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)
# Update metadata with the complete usage info
self._task_state.metadata.usage = usage
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:

View File

@ -222,6 +222,59 @@ class TencentSpanBuilder:
links=links,
)
@staticmethod
def build_message_llm_span(
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
) -> SpanData:
"""Build LLM span for message traces with detailed LLM attributes."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
# Extract model information from `metadata`` or `message_data`
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}
model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)
inputs_str = str(trace_info.inputs or "")
outputs_str = str(trace_info.outputs or "")
attributes = {
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: str(model_name),
GEN_AI_PROVIDER: str(model_provider),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
GEN_AI_PROMPT: inputs_str,
GEN_AI_COMPLETION: outputs_str,
INPUT_VALUE: inputs_str,
OUTPUT_VALUE: outputs_str,
}
if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
)
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""

View File

@ -107,9 +107,13 @@ class TencentDataTrace(BaseTraceInstance):
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
self.trace_client.add_span(message_span)
# Add LLM child span with detailed attributes
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
self.trace_client.add_span(llm_span)
self._record_message_llm_metrics(trace_info)
# Record trace duration for entry span

View File

@ -1,20 +1,110 @@
import re
from operator import itemgetter
from typing import cast
class JiebaKeywordTableHandler:
def __init__(self):
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
tfidf = self._load_tfidf_extractor()
tfidf.stop_words = STOPWORDS # type: ignore[attr-defined]
self._tfidf = tfidf
def _load_tfidf_extractor(self):
"""
Load jieba TFIDF extractor with fallback strategy.
Loading Flow:
jieba.analyse.default_tfidf
exists?
YES NO
Return default jieba.analyse.TFIDF exists?
TFIDF
YES NO
Try import from
jieba.analyse.tfidf.TFIDF
SUCCESS FAILED
Instantiate TFIDF() Build fallback
& cache to default _SimpleTFIDF
"""
import jieba.analyse # type: ignore
tfidf = getattr(jieba.analyse, "default_tfidf", None)
if tfidf is not None:
return tfidf
tfidf_class = getattr(jieba.analyse, "TFIDF", None)
if tfidf_class is None:
try:
from jieba.analyse.tfidf import TFIDF # type: ignore
tfidf_class = TFIDF
except Exception:
tfidf_class = None
if tfidf_class is not None:
tfidf = tfidf_class()
jieba.analyse.default_tfidf = tfidf # type: ignore[attr-defined]
return tfidf
return self._build_fallback_tfidf()
@staticmethod
def _build_fallback_tfidf():
"""Fallback lightweight TFIDF for environments missing jieba's TFIDF."""
import jieba # type: ignore
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore
class _SimpleTFIDF:
def __init__(self):
self.stop_words = STOPWORDS
self._lcut = getattr(jieba, "lcut", None)
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
top_k = kwargs.pop("topK", top_k)
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
tokens = list(cut(sentence))
else:
tokens = re.findall(r"\w+", sentence)
words = [w for w in tokens if w and w not in self.stop_words]
freq: dict[str, int] = {}
for w in words:
freq[w] = freq.get(w, 0) + 1
sorted_words = sorted(freq.items(), key=itemgetter(1), reverse=True)
if top_k is not None:
sorted_words = sorted_words[:top_k]
return [item[0] for item in sorted_words]
return _SimpleTFIDF()
def extract_keywords(self, text: str, max_keywords_per_chunk: int | None = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
import jieba.analyse # type: ignore
keywords = jieba.analyse.extract_tags(
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)

View File

@ -302,8 +302,7 @@ class OracleVector(BaseVector):
nltk.data.find("tokenizers/punkt")
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download("punkt")
nltk.download("stopwords")
raise LookupError("Unable to find the required NLTK data package: punkt and stopwords")
e_str = re.sub(r"[^\w ]", "", query)
all_tokens = nltk.word_tokenize(e_str)
stop_words = stopwords.words("english")

View File

@ -167,13 +167,18 @@ class WeaviateVector(BaseVector):
try:
if not self._client.collections.exists(self._collection_name):
tokenization = (
wc.Tokenization(dify_config.WEAVIATE_TOKENIZATION)
if dify_config.WEAVIATE_TOKENIZATION
else wc.Tokenization.WORD
)
self._client.collections.create(
name=self._collection_name,
properties=[
wc.Property(
name=Field.TEXT_KEY.value,
data_type=wc.DataType.TEXT,
tokenization=wc.Tokenization.WORD,
tokenization=tokenization,
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),

View File

@ -141,6 +141,7 @@ class WorkflowToolProviderController(ToolProviderController):
form=parameter.form,
llm_description=parameter.description,
required=variable.required,
default=variable.default,
options=options,
placeholder=I18nObject(en_US="", zh_Hans=""),
)

View File

@ -71,6 +71,11 @@ class TriggerProviderIdentity(BaseModel):
icon_dark: str | None = Field(default=None, description="The dark icon of the trigger provider")
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
@field_validator("tags", mode="before")
@classmethod
def validate_tags(cls, v: list[str] | None) -> list[str]:
return v or []
class EventIdentity(BaseModel):
"""

View File

@ -1,17 +1,11 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowPauseEntity",
]

View File

@ -1,49 +1,26 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias
from typing import Annotated, Literal, TypeAlias
from pydantic import BaseModel, Discriminator, Tag
from pydantic import BaseModel, Field
class _PauseReasonType(StrEnum):
class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()
class _PauseReasonBase(BaseModel):
TYPE: ClassVar[_PauseReasonType]
class HumanInputRequired(BaseModel):
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
# The identifier of the human input node causing the pause.
node_id: str
class HumanInputRequired(_PauseReasonBase):
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
class SchedulingPause(_PauseReasonBase):
TYPE = _PauseReasonType.SCHEDULED_PAUSE
class SchedulingPause(BaseModel):
TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
message: str
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
if isinstance(v, _PauseReasonBase):
return v.TYPE
elif isinstance(v, dict):
reason_type_str = v.get("TYPE")
if reason_type_str is None:
return None
try:
reason_type = _PauseReasonType(reason_type_str)
except ValueError:
return None
return reason_type
else:
# return None if the discriminator value isn't found
return None
PauseReason: TypeAlias = Annotated[
(
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
),
Discriminator(_get_pause_reason_discriminator),
]
PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]

View File

@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reason: PauseReason | None = Field(default=None)
pause_reasons: list[PauseReason] = Field(default_factory=list)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@ -107,7 +107,7 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reason: PauseReason | None = None
pause_reasons: list[PauseReason] = field(default_factory=list)
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
@ -137,10 +137,8 @@ class GraphExecution:
raise RuntimeError("Cannot pause execution that has completed")
if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted")
if self.paused:
return
self.paused = True
self.pause_reason = reason
self.pause_reasons.append(reason)
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
@ -195,7 +193,7 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
pause_reason=self.pause_reason,
pause_reasons=self.pause_reasons,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
@ -221,7 +219,7 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
self.pause_reason = state.pause_reason
self.pause_reasons = state.pause_reasons
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {

View File

@ -110,7 +110,13 @@ class EventManager:
"""
with self._lock.write_lock():
self._events.append(event)
self._notify_layers(event)
# NOTE: `_notify_layers` is intentionally called outside the critical section
# to minimize lock contention and avoid blocking other readers or writers.
#
# The public `notify_layers` method also does not use a write lock,
# so protecting `_notify_layers` with a lock here is unnecessary.
self._notify_layers(event)
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""

View File

@ -232,7 +232,7 @@ class GraphEngine:
self._graph_execution.start()
else:
self._graph_execution.paused = False
self._graph_execution.pause_reason = None
self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event)
@ -246,11 +246,11 @@ class GraphEngine:
# Handle completion
if self._graph_execution.is_paused:
pause_reason = self._graph_execution.pause_reason
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
pause_reasons = self._graph_execution.pause_reasons
assert pause_reasons, "pause_reasons should not be empty when execution is paused."
# Ensure we have a valid PauseReason for the event
paused_event = GraphRunPausedEvent(
reason=pause_reason,
reasons=pause_reasons,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)

View File

@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""
# reason: str | None = Field(default=None, description="reason for pause")
reason: PauseReason = Field(..., description="reason for pause")
reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",

View File

@ -65,7 +65,8 @@ class HumanInputNode(Node):
return self._pause_generator()
def _pause_generator(self):
yield PauseRequestedEvent(reason=HumanInputRequired())
# TODO(QuantumGhost): yield a real form id.
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""

View File

@ -229,6 +229,8 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
return lambda x: x.transfer_method
case "url":
return lambda x: x.remote_url or ""
case "related_id":
return lambda x: x.related_id or ""
case _:
raise InvalidKeyError(f"Invalid key: {key}")
@ -299,7 +301,7 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
if key in {"type", "transfer_method"}:
@ -358,7 +360,7 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
elif order_by == "size":

View File

@ -329,7 +329,15 @@ class ToolNode(Node):
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
# Check if this LINK message is a file link
file_obj = (message.meta or {}).get("file")
if isinstance(file_obj, File):
files.append(file_obj)
stream_text = f"File: {message.message.text}\n"
else:
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],

View File

@ -10,6 +10,7 @@ from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool
@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol):
class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate."""
"""Structural interface for graph execution aggregate.
Defines the minimal set of attributes and methods required from a GraphExecution entity
for runtime orchestration and state management.
"""
workflow_id: str
started: bool
@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""

View File

@ -112,7 +112,7 @@ class Storage:
def exists(self, filename):
return self.storage_runner.exists(filename)
def delete(self, filename):
def delete(self, filename: str):
return self.storage_runner.delete(filename)
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:

View File

@ -0,0 +1,41 @@
"""Add workflow_pauses_reasons table
Revision ID: 7bb281b7a422
Revises: 09cfdda155d1
Create Date: 2025-11-18 18:59:26.999572
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7bb281b7a422"
down_revision = "09cfdda155d1"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"workflow_pause_reasons",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("pause_id", models.types.StringUUID(), nullable=False),
sa.Column("type_", sa.String(20), nullable=False),
sa.Column("form_id", sa.String(length=36), nullable=False),
sa.Column("node_id", sa.String(length=255), nullable=False),
sa.Column("message", sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("workflow_pause_reasons_pkey")),
)
with op.batch_alter_table("workflow_pause_reasons", schema=None) as batch_op:
batch_op.create_index(batch_op.f("workflow_pause_reasons_pause_id_idx"), ["pause_id"], unique=False)
def downgrade():
op.drop_table("workflow_pause_reasons")

View File

@ -88,7 +88,9 @@ class Account(UserMixin, TypeBase):
__tablename__ = "accounts"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[str | None] = mapped_column(String(255), default=None)
@ -235,7 +237,9 @@ class Tenant(TypeBase):
__tablename__ = "tenants"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic")
@ -275,7 +279,9 @@ class TenantAccountJoin(TypeBase):
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
@ -297,7 +303,9 @@ class AccountIntegrate(TypeBase):
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))
@ -348,7 +356,9 @@ class TenantPluginPermission(TypeBase):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE
@ -375,7 +385,9 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY

View File

@ -24,7 +24,9 @@ class APIBasedExtension(TypeBase):
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -920,7 +920,12 @@ class AppDatasetJoin(TypeBase):
)
id: Mapped[str] = mapped_column(
StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
StringUUID,
primary_key=True,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -941,7 +946,12 @@ class DatasetQuery(TypeBase):
)
id: Mapped[str] = mapped_column(
StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
StringUUID,
primary_key=True,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
@ -961,7 +971,13 @@ class DatasetKeywordTable(TypeBase):
sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False, unique=True)
keyword_table: Mapped[str] = mapped_column(LongText, nullable=False)
data_source_type: Mapped[str] = mapped_column(
@ -1012,7 +1028,13 @@ class Embedding(TypeBase):
sa.Index("created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
model_name: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'")
)
@ -1037,7 +1059,13 @@ class DatasetCollectionBinding(TypeBase):
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
@ -1073,7 +1101,13 @@ class Whitelist(TypeBase):
sa.PrimaryKeyConstraint("id", name="whitelists_pkey"),
sa.Index("whitelists_tenant_idx", "tenant_id"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
category: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
@ -1090,7 +1124,13 @@ class DatasetPermission(TypeBase):
sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), primary_key=True, init=False)
id: Mapped[str] = mapped_column(
StringUUID,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
primary_key=True,
init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1110,7 +1150,13 @@ class ExternalKnowledgeApis(TypeBase):
sa.Index("external_knowledge_apis_name_idx", "name"),
)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1167,7 +1213,13 @@ class ExternalKnowledgeBindings(TypeBase):
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
external_knowledge_api_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1191,7 +1243,9 @@ class DatasetAutoDisableLog(TypeBase):
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1209,7 +1263,9 @@ class RateLimitLog(TypeBase):
sa.Index("rate_limit_log_operation_idx", "operation"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False)
operation: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1226,7 +1282,9 @@ class DatasetMetadata(TypeBase):
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1255,7 +1313,9 @@ class DatasetMetadataBinding(TypeBase):
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
metadata_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1270,7 +1330,9 @@ class PipelineBuiltInTemplate(TypeBase):
__tablename__ = "pipeline_built_in_templates"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
@ -1300,7 +1362,9 @@ class PipelineCustomizedTemplate(TypeBase):
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
@ -1335,7 +1399,9 @@ class Pipeline(TypeBase):
__tablename__ = "pipelines"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("''"))
@ -1368,7 +1434,9 @@ class DocumentPipelineExecutionLog(TypeBase):
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
pipeline_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
@ -1385,7 +1453,9 @@ class PipelineRecommendedPlugin(TypeBase):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)

View File

@ -572,7 +572,9 @@ class InstalledApp(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_owner_tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -606,7 +608,9 @@ class OAuthProviderApp(TypeBase):
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1251,9 +1255,13 @@ class Message(Base):
"id": self.id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"model_provider": self.model_provider,
"model_id": self.model_id,
"inputs": self.inputs,
"query": self.query,
"message_tokens": self.message_tokens,
"answer_tokens": self.answer_tokens,
"provider_response_latency": self.provider_response_latency,
"total_price": self.total_price,
"message": self.message,
"answer": self.answer,
@ -1275,8 +1283,12 @@ class Message(Base):
id=data["id"],
app_id=data["app_id"],
conversation_id=data["conversation_id"],
model_provider=data.get("model_provider"),
model_id=data["model_id"],
inputs=data["inputs"],
message_tokens=data.get("message_tokens", 0),
answer_tokens=data.get("answer_tokens", 0),
provider_response_latency=data.get("provider_response_latency", 0.0),
total_price=data["total_price"],
query=data["query"],
message=data["message"],
@ -1303,7 +1315,9 @@ class MessageFeedback(TypeBase):
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1352,7 +1366,9 @@ class MessageFile(TypeBase):
sa.Index("message_file_created_by_idx", "created_by"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)
@ -1444,7 +1460,9 @@ class AppAnnotationSetting(TypeBase):
sa.Index("app_annotation_settings_app_idx", "app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
score_threshold: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
collection_binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1480,7 +1498,9 @@ class OperationLog(TypeBase):
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
action: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1546,7 +1566,9 @@ class AppMCPServer(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1756,7 +1778,9 @@ class ApiRequest(TypeBase):
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
api_token_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
path: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1775,7 +1799,9 @@ class MessageChain(TypeBase):
sa.Index("message_chain_message_id_idx", "message_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
@ -1906,7 +1932,9 @@ class DatasetRetrieverResource(TypeBase):
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1938,7 +1966,9 @@ class Tag(TypeBase):
TAG_TYPE_LIST = ["knowledge", "app"]
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1956,7 +1986,9 @@ class TagBinding(TypeBase):
sa.Index("tag_bind_tag_id_idx", "tag_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
tag_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
target_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
@ -1973,7 +2005,9 @@ class TraceAppConfig(TypeBase):
sa.Index("trace_app_config_app_id_idx", "app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True)

View File

@ -17,7 +17,9 @@ class DatasourceOauthParamConfig(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
@ -30,7 +32,9 @@ class DatasourceProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
@ -60,7 +64,9 @@ class DatasourceOauthTenantParamConfig(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)

View File

@ -58,7 +58,13 @@ class Provider(TypeBase):
),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuidv7()),
default_factory=lambda: str(uuidv7()),
init=False,
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
@ -132,7 +138,9 @@ class ProviderModel(TypeBase):
),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -173,7 +181,9 @@ class TenantDefaultModel(TypeBase):
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -193,7 +203,9 @@ class TenantPreferredModelProvider(TypeBase):
sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
@ -212,7 +224,9 @@ class ProviderOrder(TypeBase):
sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -245,7 +259,9 @@ class ProviderModelSetting(TypeBase):
sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -273,7 +289,9 @@ class LoadBalancingModelConfig(TypeBase):
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -302,7 +320,9 @@ class ProviderCredential(TypeBase):
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -332,7 +352,9 @@ class ProviderModelCredential(TypeBase):
),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -18,7 +18,9 @@ class DataSourceOauthBinding(TypeBase):
adjusted_json_index("source_info_idx", "source_info"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
@ -44,7 +46,9 @@ class DataSourceApiKeyAuthBinding(TypeBase):
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
category: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -24,7 +24,8 @@ class CeleryTask(TypeBase):
result: Mapped[bytes | None] = mapped_column(BinaryData, nullable=True, default=None)
date_done: Mapped[datetime | None] = mapped_column(
DateTime,
default=naive_utc_now,
insert_default=naive_utc_now,
default=None,
onupdate=naive_utc_now,
nullable=True,
)
@ -47,4 +48,6 @@ class CeleryTaskSet(TypeBase):
)
taskset_id: Mapped[str] = mapped_column(String(155), unique=True)
result: Mapped[bytes | None] = mapped_column(BinaryData, nullable=True, default=None)
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=naive_utc_now, nullable=True)
date_done: Mapped[datetime | None] = mapped_column(
DateTime, insert_default=naive_utc_now, default=None, nullable=True
)

View File

@ -30,7 +30,9 @@ class ToolOAuthSystemClient(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the tool provider
@ -45,7 +47,9 @@ class ToolOAuthTenantClient(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
@ -71,7 +75,9 @@ class BuiltinToolProvider(TypeBase):
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
name: Mapped[str] = mapped_column(
String(256),
nullable=False,
@ -120,7 +126,9 @@ class ApiToolProvider(TypeBase):
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# name of the api provider
name: Mapped[str] = mapped_column(
String(255),
@ -192,7 +200,9 @@ class ToolLabelBinding(TypeBase):
sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
@ -213,7 +223,9 @@ class WorkflowToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# name of the workflow provider
name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider
@ -279,7 +291,9 @@ class MCPToolProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# name of the mcp provider
name: Mapped[str] = mapped_column(String(40), nullable=False)
# server identifier of the mcp provider
@ -360,7 +374,9 @@ class ToolModelInvoke(TypeBase):
__tablename__ = "tool_model_invokes"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# who invoke this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
@ -413,7 +429,9 @@ class ToolConversationVariables(TypeBase):
sa.Index("conversation_id_idx", "conversation_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
@ -450,7 +468,9 @@ class ToolFile(TypeBase):
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id
@ -481,7 +501,9 @@ class DeprecatedPublishedAppTool(TypeBase):
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# id of the app
app_id: Mapped[str] = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)

View File

@ -41,7 +41,9 @@ class TriggerSubscription(TypeBase):
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription instance name")
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -111,7 +113,9 @@ class TriggerOAuthSystemClient(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the trigger provider
@ -136,7 +140,9 @@ class TriggerOAuthTenantClient(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
@ -202,7 +208,9 @@ class WorkflowTriggerLog(TypeBase):
sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -294,7 +302,9 @@ class WorkflowWebhookTrigger(TypeBase):
sa.UniqueConstraint("webhook_id", name="uniq_webhook_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -351,7 +361,9 @@ class WorkflowPluginTrigger(TypeBase):
sa.UniqueConstraint("app_id", "node_id", name="uniq_app_node_subscription"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -395,7 +407,9 @@ class AppTrigger(TypeBase):
sa.Index("app_trigger_tenant_app_idx", "tenant_id", "app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str | None] = mapped_column(String(64), nullable=False)
@ -443,7 +457,13 @@ class WorkflowSchedulePlan(TypeBase):
sa.Index("workflow_schedule_plan_next_idx", "next_run_at"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuidv7()),
default_factory=lambda: str(uuidv7()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

View File

@ -18,7 +18,9 @@ class SavedMessage(TypeBase):
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'"))
@ -42,7 +44,9 @@ class PinnedConversation(TypeBase):
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role: Mapped[str] = mapped_column(

View File

@ -29,6 +29,7 @@ from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from core.workflow.enums import NodeType
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
@ -1102,7 +1103,9 @@ class WorkflowAppLog(TypeBase):
sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1728,3 +1731,68 @@ class WorkflowPause(DefaultFieldsMixin, Base):
primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
back_populates="pause",
)
class WorkflowPauseReason(DefaultFieldsMixin, Base):
__tablename__ = "workflow_pause_reasons"
# `pause_id` represents the identifier of the pause,
# correspond to the `id` field of `WorkflowPause`.
pause_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
type_: Mapped[PauseReasonType] = mapped_column(EnumText(PauseReasonType), nullable=False)
# form_id is not empty if and if only type_ == PauseReasonType.HUMAN_INPUT_REQUIRED
#
form_id: Mapped[str] = mapped_column(
String(36),
nullable=False,
default="",
)
# message records the text description of this pause reason. For example,
# "The workflow has been paused due to scheduling."
#
# Empty message means that this pause reason is not speified.
message: Mapped[str] = mapped_column(
String(255),
nullable=False,
default="",
)
# `node_id` is the identifier of node causing the pasue, correspond to
# `Node.id`. Empty `node_id` means that this pause reason is not caused by any specific node
# (E.G. time slicing pauses.)
node_id: Mapped[str] = mapped_column(
String(255),
nullable=False,
default="",
)
# Relationship to WorkflowPause
pause: Mapped[WorkflowPause] = orm.relationship(
foreign_keys=[pause_id],
# require explicit preloading.
lazy="raise",
uselist=False,
primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id",
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
)
elif isinstance(pause_reason, SchedulingPause):
return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="")
else:
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
def to_entity(self) -> PauseReason:
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
return HumanInputRequired(form_id=self.form_id, node_id=self.node_id)
elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
return SchedulingPause(message=self.message)
else:
raise AssertionError(f"Unknown pause reason type: {self.type_}")

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.10.0"
version = "1.10.1"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -38,11 +38,12 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
AverageInteractionStats,
DailyRunsStats,
@ -257,6 +258,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
workflow_run_id: str,
state_owner_user_id: str,
state: str,
pause_reasons: Sequence[PauseReason],
) -> WorkflowPauseEntity:
"""
Create a new workflow pause state.

View File

@ -7,8 +7,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from datetime import datetime
from core.workflow.entities.pause_reason import PauseReason
class WorkflowPauseEntity(ABC):
"""
@ -59,3 +62,15 @@ class WorkflowPauseEntity(ABC):
the pause is not resumed yet.
"""
pass
@abstractmethod
def get_pause_reasons(self) -> Sequence[PauseReason]:
"""
Retrieve detailed reasons for this pause.
Returns a sequence of `PauseReason` objects describing the specific nodes and
reasons for which the workflow execution was paused.
This information is related to, but distinct from, the `PauseReason` type
defined in `api/core/workflow/entities/pause_reason.py`.
"""
...

View File

@ -31,7 +31,7 @@ from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
@ -41,8 +41,9 @@ from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from models.workflow import WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
AverageInteractionStats,
DailyRunsStats,
@ -318,6 +319,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
workflow_run_id: str,
state_owner_user_id: str,
state: str,
pause_reasons: Sequence[PauseReason],
) -> WorkflowPauseEntity:
"""
Create a new workflow pause state.
@ -371,6 +373,25 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_model.workflow_run_id = workflow_run.id
pause_model.state_object_key = state_obj_key
pause_model.created_at = naive_utc_now()
pause_reason_models = []
for reason in pause_reasons:
if isinstance(reason, HumanInputRequired):
# TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
form_id=reason.form_id,
)
elif isinstance(reason, SchedulingPause):
pause_reason_model = WorkflowPauseReason(
pause_id=pause_model.id,
type_=reason.TYPE,
message=reason.message,
)
else:
raise AssertionError(f"unkown reason type: {type(reason)}")
pause_reason_models.append(pause_reason_model)
# Update workflow run status
workflow_run.status = WorkflowExecutionStatus.PAUSED
@ -378,10 +399,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Save everything in a transaction
session.add(pause_model)
session.add(workflow_run)
session.add_all(pause_reason_models)
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
return _PrivateWorkflowPauseEntity.from_models(pause_model)
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
pause_reason_models = session.scalars(reason_stmt).all()
return pause_reason_models
def get_workflow_pause(
self,
@ -413,8 +440,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pause_model = workflow_run.pause
if pause_model is None:
return None
pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
return _PrivateWorkflowPauseEntity.from_models(pause_model)
human_input_form: list[Any] = []
# TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
return _PrivateWorkflowPauseEntity(
pause_model=pause_model,
reason_models=pause_reason_models,
human_input_form=human_input_form,
)
def resume_workflow_pause(
self,
@ -466,6 +501,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
if pause_model.resumed_at is not None:
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
# Mark as resumed
pause_model.resumed_at = naive_utc_now()
workflow_run.pause_id = None # type: ignore
@ -476,7 +513,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
return _PrivateWorkflowPauseEntity.from_models(pause_model)
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)
def delete_workflow_pause(
self,
@ -815,26 +852,13 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
self,
*,
pause_model: WorkflowPauseModel,
reason_models: Sequence[WorkflowPauseReason],
human_input_form: Sequence = (),
) -> None:
self._pause_model = pause_model
self._reason_models = reason_models
self._cached_state: bytes | None = None
@classmethod
def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity":
"""
Create a _PrivateWorkflowPauseEntity from database models.
Args:
workflow_pause_model: The WorkflowPause database model
upload_file_model: The UploadFile database model
Returns:
_PrivateWorkflowPauseEntity: The constructed entity
Raises:
ValueError: If required model attributes are missing
"""
return cls(pause_model=workflow_pause_model)
self._human_input_form = human_input_form
@property
def id(self) -> str:
@ -867,3 +891,6 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
@property
def resumed_at(self) -> datetime | None:
return self._pause_model.resumed_at
def get_pause_reasons(self) -> Sequence[PauseReason]:
return [reason.to_entity() for reason in self._reason_models]

View File

@ -1352,7 +1352,7 @@ class RegisterService:
@classmethod
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None
) -> str:
if not inviter:
raise ValueError("Inviter is required")

View File

@ -550,7 +550,7 @@ class AppDslService:
"app": {
"name": app_model.name,
"mode": app_model.mode,
"icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
"icon": app_model.icon if app_model.icon_type == "image" else "🤖",
"icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
"description": app_model.description,
"use_icon_as_answer_icon": app_model.use_icon_as_answer_icon,

View File

@ -0,0 +1,45 @@
"""Service for managing application task operations.
This service provides centralized logic for task control operations
like stopping tasks, handling both legacy Redis flag mechanism and
new GraphEngine command channel mechanism.
"""
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.graph_engine.manager import GraphEngineManager
from models.model import AppMode
class AppTaskService:
"""Service for managing application task operations."""
@staticmethod
def stop_task(
task_id: str,
invoke_from: InvokeFrom,
user_id: str,
app_mode: AppMode,
) -> None:
"""Stop a running task.
This method handles stopping tasks using both mechanisms:
1. Legacy Redis flag mechanism (for backward compatibility)
2. New GraphEngine command channel (for workflow-based apps)
Args:
task_id: The task ID to stop
invoke_from: The source of the invoke (e.g., DEBUGGER, WEB_APP, SERVICE_API)
user_id: The user ID requesting the stop
app_mode: The application mode (CHAT, AGENT_CHAT, ADVANCED_CHAT, WORKFLOW, etc.)
Returns:
None
"""
# Legacy mechanism: Set stop flag in Redis
AppQueueManager.set_stop_flag(task_id, invoke_from, user_id)
# New mechanism: Send stop command via GraphEngine for workflow-based apps
# This ensures proper workflow status recording in the persistence layer
if app_mode in (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW):
GraphEngineManager.send_stop_command(task_id)

View File

@ -1375,6 +1375,11 @@ class DocumentService:
document.name = name
db.session.add(document)
if document.data_source_info_dict:
db.session.query(UploadFile).where(
UploadFile.id == document.data_source_info_dict["upload_file_id"]
).update({UploadFile.name: name})
db.session.commit()
return document

View File

@ -0,0 +1,185 @@
import csv
import io
import json
from datetime import datetime
from flask import Response
from sqlalchemy import or_
from extensions.ext_database import db
from models.model import Account, App, Conversation, Message, MessageFeedback
class FeedbackService:
@staticmethod
def export_feedbacks(
app_id: str,
from_source: str | None = None,
rating: str | None = None,
has_comment: bool | None = None,
start_date: str | None = None,
end_date: str | None = None,
format_type: str = "csv",
):
"""
Export feedback data with message details for analysis
Args:
app_id: Application ID
from_source: Filter by feedback source ('user' or 'admin')
rating: Filter by rating ('like' or 'dislike')
has_comment: Only include feedback with comments
start_date: Start date filter (YYYY-MM-DD)
end_date: End date filter (YYYY-MM-DD)
format_type: Export format ('csv' or 'json')
"""
# Validate format early to avoid hitting DB when unnecessary
fmt = (format_type or "csv").lower()
if fmt not in {"csv", "json"}:
raise ValueError(f"Unsupported format: {format_type}")
# Build base query
query = (
db.session.query(MessageFeedback, Message, Conversation, App, Account)
.join(Message, MessageFeedback.message_id == Message.id)
.join(Conversation, MessageFeedback.conversation_id == Conversation.id)
.join(App, MessageFeedback.app_id == App.id)
.outerjoin(Account, MessageFeedback.from_account_id == Account.id)
.where(MessageFeedback.app_id == app_id)
)
# Apply filters
if from_source:
query = query.filter(MessageFeedback.from_source == from_source)
if rating:
query = query.filter(MessageFeedback.rating == rating)
if has_comment is not None:
if has_comment:
query = query.filter(MessageFeedback.content.isnot(None), MessageFeedback.content != "")
else:
query = query.filter(or_(MessageFeedback.content.is_(None), MessageFeedback.content == ""))
if start_date:
try:
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
query = query.filter(MessageFeedback.created_at >= start_dt)
except ValueError:
raise ValueError(f"Invalid start_date format: {start_date}. Use YYYY-MM-DD")
if end_date:
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
query = query.filter(MessageFeedback.created_at <= end_dt)
except ValueError:
raise ValueError(f"Invalid end_date format: {end_date}. Use YYYY-MM-DD")
# Order by creation date (newest first)
query = query.order_by(MessageFeedback.created_at.desc())
# Execute query
results = query.all()
# Prepare data for export
export_data = []
for feedback, message, conversation, app, account in results:
# Get the user query from the message
user_query = message.query or message.inputs.get("query", "") if message.inputs else ""
# Format the feedback data
feedback_record = {
"feedback_id": str(feedback.id),
"app_name": app.name,
"app_id": str(app.id),
"conversation_id": str(conversation.id),
"conversation_name": conversation.name or "",
"message_id": str(message.id),
"user_query": user_query,
"ai_response": message.answer[:500] + "..."
if len(message.answer) > 500
else message.answer, # Truncate long responses
"feedback_rating": "👍" if feedback.rating == "like" else "👎",
"feedback_rating_raw": feedback.rating,
"feedback_comment": feedback.content or "",
"feedback_source": feedback.from_source,
"feedback_date": feedback.created_at.strftime("%Y-%m-%d %H:%M:%S"),
"message_date": message.created_at.strftime("%Y-%m-%d %H:%M:%S"),
"from_account_name": account.name if account else "",
"from_end_user_id": str(feedback.from_end_user_id) if feedback.from_end_user_id else "",
"has_comment": "Yes" if feedback.content and feedback.content.strip() else "No",
}
export_data.append(feedback_record)
# Export based on format
if fmt == "csv":
return FeedbackService._export_csv(export_data, app_id)
else: # fmt == "json"
return FeedbackService._export_json(export_data, app_id)
@staticmethod
def _export_csv(data, app_id):
"""Export data as CSV"""
if not data:
pass # allow empty CSV with headers only
# Create CSV in memory
output = io.StringIO()
# Define headers
headers = [
"feedback_id",
"app_name",
"app_id",
"conversation_id",
"conversation_name",
"message_id",
"user_query",
"ai_response",
"feedback_rating",
"feedback_rating_raw",
"feedback_comment",
"feedback_source",
"feedback_date",
"message_date",
"from_account_name",
"from_end_user_id",
"has_comment",
]
writer = csv.DictWriter(output, fieldnames=headers)
writer.writeheader()
writer.writerows(data)
# Create response without requiring app context
response = Response(output.getvalue(), mimetype="text/csv; charset=utf-8-sig")
response.headers["Content-Disposition"] = (
f"attachment; filename=dify_feedback_export_{app_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
)
return response
@staticmethod
def _export_json(data, app_id):
"""Export data as JSON"""
response_data = {
"export_info": {
"app_id": app_id,
"export_date": datetime.now().isoformat(),
"total_records": len(data),
"data_source": "dify_feedback_export",
},
"feedback_data": data,
}
# Create response without requiring app context
response = Response(
json.dumps(response_data, ensure_ascii=False, indent=2),
mimetype="application/json; charset=utf-8",
)
response.headers["Content-Disposition"] = (
f"attachment; filename=dify_feedback_export_{app_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
)
return response

View File

@ -3,8 +3,8 @@ import os
import uuid
from typing import Literal, Union
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import NotFound
from configs import dify_config
@ -29,7 +29,7 @@ PREVIEW_WORDS_LIMIT = 3000
class FileService:
_session_maker: sessionmaker
_session_maker: sessionmaker[Session]
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
@ -236,11 +236,10 @@ class FileService:
return content.decode("utf-8")
def delete_file(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
with self._session_maker() as session, session.begin():
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id))
if not upload_file:
return
storage.delete(upload_file.key)
session.delete(upload_file)
session.commit()
if not upload_file:
return
storage.delete(upload_file.key)
session.delete(upload_file)

View File

@ -5,6 +5,7 @@ import secrets
from collections.abc import Mapping
from typing import Any
import orjson
from flask import request
from pydantic import BaseModel
from sqlalchemy import select
@ -169,7 +170,7 @@ class WebhookService:
- method: HTTP method
- headers: Request headers
- query_params: Query parameters as strings
- body: Request body (varies by content type)
- body: Request body (varies by content type; JSON parsing errors raise ValueError)
- files: Uploaded files (if any)
"""
cls._validate_content_length()
@ -255,14 +256,21 @@ class WebhookService:
Returns:
tuple: (body_data, files_data) where:
- body_data: Parsed JSON content or empty dict if parsing fails
- body_data: Parsed JSON content
- files_data: Empty dict (JSON requests don't contain files)
Raises:
ValueError: If JSON parsing fails
"""
raw_body = request.get_data(cache=True)
if not raw_body or raw_body.strip() == b"":
return {}, {}
try:
body = request.get_json() or {}
except Exception:
logger.warning("Failed to parse JSON body")
body = {}
body = orjson.loads(raw_body)
except orjson.JSONDecodeError as exc:
logger.warning("Failed to parse JSON body: %s", exc)
raise ValueError(f"Invalid JSON body: {exc}") from exc
return body, {}
@classmethod

View File

@ -15,7 +15,7 @@ from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities import VariablePool, WorkflowNodeExecution
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
@ -24,6 +24,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan

View File

@ -62,6 +62,7 @@ WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word
# Upload configuration

View File

@ -0,0 +1,106 @@
"""Basic integration tests for Feedback API endpoints."""
import uuid
from flask.testing import FlaskClient
class TestFeedbackApiBasic:
"""Basic tests for feedback API endpoints."""
def test_feedback_export_endpoint_exists(self, test_client: FlaskClient, auth_header):
"""Test that feedback export endpoint exists and handles basic requests."""
app_id = str(uuid.uuid4())
# Test endpoint exists (even if it fails, it should return 500 or 403, not 404)
response = test_client.get(
f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string={"format": "csv"}
)
# Should not return 404 (endpoint exists)
assert response.status_code != 404
# Should return authentication or permission error
assert response.status_code in [401, 403, 500] # 500 if app doesn't exist, 403 if no permission
def test_feedback_summary_endpoint_exists(self, test_client: FlaskClient, auth_header):
"""Test that feedback summary endpoint exists and handles basic requests."""
app_id = str(uuid.uuid4())
# Test endpoint exists
response = test_client.get(f"/console/api/apps/{app_id}/feedbacks/summary", headers=auth_header)
# Should not return 404 (endpoint exists)
assert response.status_code != 404
# Should return authentication or permission error
assert response.status_code in [401, 403, 500]
def test_feedback_export_invalid_format(self, test_client: FlaskClient, auth_header):
"""Test feedback export endpoint with invalid format parameter."""
app_id = str(uuid.uuid4())
# Test with invalid format
response = test_client.get(
f"/console/api/apps/{app_id}/feedbacks/export",
headers=auth_header,
query_string={"format": "invalid_format"},
)
# Should not return 404
assert response.status_code != 404
def test_feedback_export_with_filters(self, test_client: FlaskClient, auth_header):
"""Test feedback export endpoint with various filter parameters."""
app_id = str(uuid.uuid4())
# Test with various filter combinations
filter_params = [
{"from_source": "user"},
{"rating": "like"},
{"has_comment": True},
{"start_date": "2024-01-01"},
{"end_date": "2024-12-31"},
{"format": "json"},
{
"from_source": "admin",
"rating": "dislike",
"has_comment": True,
"start_date": "2024-01-01",
"end_date": "2024-12-31",
"format": "csv",
},
]
for params in filter_params:
response = test_client.get(
f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string=params
)
# Should not return 404
assert response.status_code != 404
def test_feedback_export_invalid_dates(self, test_client: FlaskClient, auth_header):
"""Test feedback export endpoint with invalid date formats."""
app_id = str(uuid.uuid4())
# Test with invalid date formats
invalid_dates = [
{"start_date": "invalid-date"},
{"end_date": "not-a-date"},
{"start_date": "2024-13-01"}, # Invalid month
{"end_date": "2024-12-32"}, # Invalid day
]
for params in invalid_dates:
response = test_client.get(
f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string=params
)
# Should not return 404
assert response.status_code != 404

View File

@ -0,0 +1,334 @@
"""Integration tests for Feedback Export API endpoints."""
import json
import uuid
from datetime import datetime
from types import SimpleNamespace
from unittest import mock
import pytest
from flask.testing import FlaskClient
from controllers.console.app import message as message_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.model import AppMode, MessageFeedback
from services.feedback_service import FeedbackService
class TestFeedbackExportApi:
"""Test feedback export API endpoints."""
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.name = "Test App"
return app
@pytest.fixture
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
"""Create a mock Account for testing."""
account = Account(
name="Test User",
email="test@example.com",
)
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
account.id = str(uuid.uuid4())
# Create mock tenant
tenant = Tenant(name="Test Tenant")
tenant.id = str(uuid.uuid4())
mock_session_instance = mock.Mock()
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
mock_scalars_result = mock.Mock()
mock_scalars_result.one.return_value = tenant
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
mock_session_context = mock.Mock()
mock_session_context.__enter__.return_value = mock_session_instance
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
account.current_tenant = tenant
return account
@pytest.fixture
def sample_feedback_data(self):
"""Create sample feedback data for testing."""
app_id = str(uuid.uuid4())
conversation_id = str(uuid.uuid4())
message_id = str(uuid.uuid4())
# Mock feedback data
user_feedback = MessageFeedback(
id=str(uuid.uuid4()),
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating="like",
from_source="user",
content=None,
from_end_user_id=str(uuid.uuid4()),
from_account_id=None,
created_at=naive_utc_now(),
)
admin_feedback = MessageFeedback(
id=str(uuid.uuid4()),
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating="dislike",
from_source="admin",
content="The response was not helpful",
from_end_user_id=None,
from_account_id=str(uuid.uuid4()),
created_at=naive_utc_now(),
)
# Mock message and conversation
mock_message = SimpleNamespace(
id=message_id,
conversation_id=conversation_id,
query="What is the weather today?",
answer="It's sunny and 25 degrees outside.",
inputs={"query": "What is the weather today?"},
created_at=naive_utc_now(),
)
mock_conversation = SimpleNamespace(id=conversation_id, name="Weather Conversation", app_id=app_id)
mock_app = SimpleNamespace(id=app_id, name="Weather App")
return {
"user_feedback": user_feedback,
"admin_feedback": admin_feedback,
"message": mock_message,
"conversation": mock_conversation,
"app": mock_app,
}
@pytest.mark.parametrize(
("role", "status"),
[
(TenantAccountRole.OWNER, 200),
(TenantAccountRole.ADMIN, 200),
(TenantAccountRole.EDITOR, 200),
(TenantAccountRole.NORMAL, 403),
(TenantAccountRole.DATASET_OPERATOR, 403),
],
)
def test_feedback_export_permissions(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
status: int,
):
"""Test feedback export endpoint permissions."""
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
mock_export_feedbacks = mock.Mock(return_value="mock csv response")
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
monkeypatch.setattr(message_api, "current_user", mock_account)
# Set user role
mock_account.role = role
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
headers=auth_header,
query_string={"format": "csv"},
)
assert response.status_code == status
if status == 200:
mock_export_feedbacks.assert_called_once()
def test_feedback_export_csv_format(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
):
"""Test feedback export in CSV format."""
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
# Create mock CSV response
mock_csv_content = (
"feedback_id,app_name,conversation_id,user_query,ai_response,feedback_rating,feedback_comment\n"
)
mock_csv_content += f"{sample_feedback_data['user_feedback'].id},{sample_feedback_data['app'].name},"
mock_csv_content += f"{sample_feedback_data['conversation'].id},{sample_feedback_data['message'].query},"
mock_csv_content += f"{sample_feedback_data['message'].answer},👍,\n"
mock_response = mock.Mock()
mock_response.headers = {"Content-Type": "text/csv; charset=utf-8-sig"}
mock_response.data = mock_csv_content.encode("utf-8")
mock_export_feedbacks = mock.Mock(return_value=mock_response)
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
monkeypatch.setattr(message_api, "current_user", mock_account)
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
headers=auth_header,
query_string={"format": "csv", "from_source": "user"},
)
assert response.status_code == 200
assert "text/csv" in response.content_type
def test_feedback_export_json_format(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
):
"""Test feedback export in JSON format."""
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
mock_json_response = {
"export_info": {
"app_id": mock_app_model.id,
"export_date": datetime.now().isoformat(),
"total_records": 2,
"data_source": "dify_feedback_export",
},
"feedback_data": [
{
"feedback_id": sample_feedback_data["user_feedback"].id,
"feedback_rating": "👍",
"feedback_rating_raw": "like",
"feedback_comment": "",
}
],
}
mock_response = mock.Mock()
mock_response.headers = {"Content-Type": "application/json; charset=utf-8"}
mock_response.data = json.dumps(mock_json_response).encode("utf-8")
mock_export_feedbacks = mock.Mock(return_value=mock_response)
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
monkeypatch.setattr(message_api, "current_user", mock_account)
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
headers=auth_header,
query_string={"format": "json"},
)
assert response.status_code == 200
assert "application/json" in response.content_type
def test_feedback_export_with_filters(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
):
"""Test feedback export with various filters."""
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
mock_export_feedbacks = mock.Mock(return_value="mock filtered response")
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
monkeypatch.setattr(message_api, "current_user", mock_account)
# Test with multiple filters
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
headers=auth_header,
query_string={
"from_source": "user",
"rating": "dislike",
"has_comment": True,
"start_date": "2024-01-01",
"end_date": "2024-12-31",
"format": "csv",
},
)
assert response.status_code == 200
# Verify service was called with correct parameters
mock_export_feedbacks.assert_called_once_with(
app_id=mock_app_model.id,
from_source="user",
rating="dislike",
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",
format_type="csv",
)
def test_feedback_export_invalid_date_format(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
):
"""Test feedback export with invalid date format."""
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
# Mock the service to raise ValueError for invalid date
mock_export_feedbacks = mock.Mock(side_effect=ValueError("Invalid date format"))
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
monkeypatch.setattr(message_api, "current_user", mock_account)
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
headers=auth_header,
query_string={"start_date": "invalid-date", "format": "csv"},
)
assert response.status_code == 400
response_json = response.get_json()
assert "Parameter validation error" in response_json["error"]
def test_feedback_export_server_error(
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
):
"""Test feedback export with server error."""
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
# Mock the service to raise an exception
mock_export_feedbacks = mock.Mock(side_effect=Exception("Database connection failed"))
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
monkeypatch.setattr(message_api, "current_user", mock_account)
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
headers=auth_header,
query_string={"format": "csv"},
)
assert response.status_code == 500

View File

@ -319,7 +319,7 @@ class TestPauseStatePersistenceLayerTestContainers:
# Create pause event
event = GraphRunPausedEvent(
reason=SchedulingPause(message="test pause"),
reasons=[SchedulingPause(message="test pause")],
outputs={"intermediate": "result"},
)
@ -381,7 +381,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act - Save pause state
layer.on_event(event)
@ -390,6 +390,7 @@ class TestPauseStatePersistenceLayerTestContainers:
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
assert pause_entity is not None
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
assert pause_entity.get_pause_reasons() == event.reasons
state_bytes = pause_entity.get_state()
resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
@ -414,7 +415,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act
layer.on_event(event)
@ -448,7 +449,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act
layer.on_event(event)
@ -514,7 +515,7 @@ class TestPauseStatePersistenceLayerTestContainers:
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act
layer.on_event(event)
@ -570,7 +571,7 @@ class TestPauseStatePersistenceLayerTestContainers:
layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
# Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError):

View File

@ -295,9 +295,13 @@ class TestAPIBasedExtensionService:
original_name = created_extension.name
original_endpoint = created_extension.api_endpoint
# Update the extension
# Update the extension with guaranteed different values
new_name = fake.company()
# Ensure new endpoint is different from original
new_endpoint = f"https://{fake.domain_name()}/api"
# If by chance they're the same, generate a new one
while new_endpoint == original_endpoint:
new_endpoint = f"https://{fake.domain_name()}/api"
new_api_key = fake.password(length=20)
created_extension.name = new_name

View File

@ -0,0 +1,386 @@
"""Unit tests for FeedbackService."""
import json
from datetime import datetime
from types import SimpleNamespace
from unittest import mock
import pytest
from extensions.ext_database import db
from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
class TestFeedbackService:
"""Test FeedbackService methods."""
@pytest.fixture
def mock_db_session(self, monkeypatch):
"""Mock database session."""
mock_session = mock.Mock()
monkeypatch.setattr(db, "session", mock_session)
return mock_session
@pytest.fixture
def sample_data(self):
"""Create sample data for testing."""
app_id = "test-app-id"
# Create mock models
app = App(id=app_id, name="Test App")
conversation = Conversation(id="test-conversation-id", app_id=app_id, name="Test Conversation")
message = Message(
id="test-message-id",
conversation_id="test-conversation-id",
query="What is AI?",
answer="AI is artificial intelligence.",
inputs={"query": "What is AI?"},
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Use SimpleNamespace to avoid ORM model constructor issues
user_feedback = SimpleNamespace(
id="user-feedback-id",
app_id=app_id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="like",
from_source="user",
content="Great answer!",
from_end_user_id="user-123",
from_account_id=None,
from_account=None, # Mock account object
created_at=datetime(2024, 1, 1, 10, 5, 0),
)
admin_feedback = SimpleNamespace(
id="admin-feedback-id",
app_id=app_id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="dislike",
from_source="admin",
content="Could be more detailed",
from_end_user_id=None,
from_account_id="admin-456",
from_account=SimpleNamespace(name="Admin User"), # Mock account object
created_at=datetime(2024, 1, 1, 10, 10, 0),
)
return {
"app": app,
"conversation": conversation,
"message": message,
"user_feedback": user_feedback,
"admin_feedback": admin_feedback,
}
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in CSV format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.query.return_value = mock_query
# Test CSV export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
# Verify response structure
assert hasattr(result, "headers")
assert "text/csv" in result.headers["Content-Type"]
assert "attachment" in result.headers["Content-Disposition"]
# Check CSV content
csv_content = result.get_data(as_text=True)
# Verify essential headers exist (order may include additional columns)
assert "feedback_id" in csv_content
assert "app_name" in csv_content
assert "conversation_id" in csv_content
assert sample_data["app"].name in csv_content
assert sample_data["message"].query in csv_content
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in JSON format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.query.return_value = mock_query
# Test JSON export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
# Verify response structure
assert hasattr(result, "headers")
assert "application/json" in result.headers["Content-Type"]
assert "attachment" in result.headers["Content-Disposition"]
# Check JSON content
json_content = json.loads(result.get_data(as_text=True))
assert "export_info" in json_content
assert "feedback_data" in json_content
assert json_content["export_info"]["app_id"] == sample_data["app"].id
assert json_content["export_info"]["total_records"] == 1
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
"""Test exporting feedback with various filters."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.query.return_value = mock_query
# Test with filters
result = FeedbackService.export_feedbacks(
app_id=sample_data["app"].id,
from_source="admin",
rating="dislike",
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",
format_type="csv",
)
# Verify filters were applied
assert mock_query.filter.called
filter_calls = mock_query.filter.call_args_list
# At least three filter invocations are expected (source, rating, comment)
assert len(filter_calls) >= 3
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
"""Test exporting feedback when no data exists."""
# Setup mock query result with no data
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
mock_db_session.query.return_value = mock_query
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
# Should return an empty CSV with headers only
assert hasattr(result, "headers")
assert "text/csv" in result.headers["Content-Type"]
csv_content = result.get_data(as_text=True)
# Headers should exist (order can include additional columns)
assert "feedback_id" in csv_content
assert "app_name" in csv_content
assert "conversation_id" in csv_content
# No data rows expected
assert len([line for line in csv_content.strip().splitlines() if line.strip()]) == 1
def test_export_feedbacks_invalid_date_format(self, mock_db_session, sample_data):
"""Test exporting feedback with invalid date format."""
# Test with invalid start_date
with pytest.raises(ValueError, match="Invalid start_date format"):
FeedbackService.export_feedbacks(app_id=sample_data["app"].id, start_date="invalid-date-format")
# Test with invalid end_date
with pytest.raises(ValueError, match="Invalid end_date format"):
FeedbackService.export_feedbacks(app_id=sample_data["app"].id, end_date="invalid-date-format")
def test_export_feedbacks_invalid_format(self, mock_db_session, sample_data):
"""Test exporting feedback with unsupported format."""
with pytest.raises(ValueError, match="Unsupported format"):
FeedbackService.export_feedbacks(
app_id=sample_data["app"].id,
format_type="xml", # Unsupported format
)
def test_export_feedbacks_long_response_truncation(self, mock_db_session, sample_data):
"""Test that long AI responses are truncated in export."""
# Create message with long response
long_message = Message(
id="long-message-id",
conversation_id="test-conversation-id",
query="What is AI?",
answer="A" * 600, # 600 character response
inputs={"query": "What is AI?"},
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.query.return_value = mock_query
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
# Check JSON content
json_content = json.loads(result.get_data(as_text=True))
exported_answer = json_content["feedback_data"][0]["ai_response"]
# Should be truncated with ellipsis
assert len(exported_answer) <= 503 # 500 + "..."
assert exported_answer.endswith("...")
assert len(exported_answer) > 500 # Should be close to limit
def test_export_feedbacks_unicode_content(self, mock_db_session, sample_data):
"""Test exporting feedback with unicode content (Chinese characters)."""
# Create feedback with Chinese content (use SimpleNamespace to avoid ORM constructor constraints)
chinese_feedback = SimpleNamespace(
id="chinese-feedback-id",
app_id=sample_data["app"].id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="dislike",
from_source="user",
content="回答不够详细,需要更多信息",
from_end_user_id="user-123",
from_account_id=None,
created_at=datetime(2024, 1, 1, 10, 5, 0),
)
# Create Chinese message
chinese_message = Message(
id="chinese-message-id",
conversation_id="test-conversation-id",
query="什么是人工智能?",
answer="人工智能是模拟人类智能的技术。",
inputs={"query": "什么是人工智能?"},
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None, # No account for user feedback
)
]
mock_db_session.query.return_value = mock_query
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
# Check that unicode content is preserved
csv_content = result.get_data(as_text=True)
assert "什么是人工智能?" in csv_content
assert "回答不够详细,需要更多信息" in csv_content
assert "人工智能是模拟人类智能的技术" in csv_content
def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
"""Test that rating emojis are properly formatted in export."""
# Setup mock query result with both like and dislike feedback
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
mock_db_session.query.return_value = mock_query
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
# Check JSON content for emoji ratings
json_content = json.loads(result.get_data(as_text=True))
feedback_data = json_content["feedback_data"]
# Should have both feedback records
assert len(feedback_data) == 2
# Check that emojis are properly set
like_feedback = next(f for f in feedback_data if f["feedback_rating_raw"] == "like")
dislike_feedback = next(f for f in feedback_data if f["feedback_rating_raw"] == "dislike")
assert like_feedback["feedback_rating"] == "👍"
assert dislike_feedback["feedback_rating"] == "👎"

View File

@ -334,12 +334,14 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
# Assert - Pause state created
assert pause_entity is not None
assert pause_entity.id is not None
assert pause_entity.workflow_execution_id == workflow_run.id
assert list(pause_entity.get_pause_reasons()) == []
# Convert both to strings for comparison
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
@ -366,6 +368,7 @@ class TestWorkflowPauseIntegration:
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
assert list(retrieved_entity.get_pause_reasons()) == []
# Act - Resume workflow
resumed_entity = repository.resume_workflow_pause(
@ -402,6 +405,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
assert pause_entity is not None
@ -432,6 +436,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
@ -449,6 +454,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
self.session.refresh(workflow_run)
@ -480,6 +486,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
self.session.refresh(workflow_run)
@ -503,6 +510,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.resumed_at = naive_utc_now()
@ -530,6 +538,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=nonexistent_id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
def test_resume_nonexistent_workflow_run(self):
@ -543,6 +552,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
nonexistent_id = str(uuid.uuid4())
@ -570,6 +580,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
# Manually adjust timestamps for testing
@ -648,6 +659,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
pause_entities.append(pause_entity)
@ -750,6 +762,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run1.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
# Try to access pause from tenant 2 using tenant 1's repository
@ -762,6 +775,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run2.id,
state_owner_user_id=account2.id,
state=test_state,
pause_reasons=[],
)
# Assert - Both pauses should exist and be separate
@ -782,6 +796,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
# Verify pause is properly scoped
@ -802,6 +817,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
pause_reasons=[],
)
# Assert - Verify file was uploaded to storage
@ -828,9 +844,7 @@ class TestWorkflowPauseIntegration:
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, pause_reasons=[]
)
# Get file info before deletion
@ -868,6 +882,7 @@ class TestWorkflowPauseIntegration:
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=large_state_json,
pause_reasons=[],
)
# Assert
@ -902,9 +917,7 @@ class TestWorkflowPauseIntegration:
# Pause
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=state,
workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=state, pause_reasons=[]
)
assert pause_entity is not None

View File

@ -31,7 +31,7 @@ class TestDataFactory:
@staticmethod
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
return GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")], outputs=outputs or {})
@staticmethod
def create_graph_run_started_event() -> GraphRunStartedEvent:
@ -255,15 +255,17 @@ class TestPauseStatePersistenceLayer:
layer.on_event(event)
mock_factory.assert_called_once_with(session_factory)
mock_repo.create_workflow_pause.assert_called_once_with(
workflow_run_id="run-123",
state_owner_user_id="owner-123",
state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
)
serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
assert mock_repo.create_workflow_pause.call_count == 1
call_kwargs = mock_repo.create_workflow_pause.call_args.kwargs
assert call_kwargs["workflow_run_id"] == "run-123"
assert call_kwargs["state_owner_user_id"] == "owner-123"
serialized_state = call_kwargs["state"]
resumption_context = WorkflowResumptionContext.loads(serialized_state)
assert resumption_context.serialized_graph_runtime_state == expected_state
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
pause_reasons = call_kwargs["pause_reasons"]
assert isinstance(pause_reasons, list)
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")

View File

@ -19,38 +19,18 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model.resumed_at = None
# Create entity
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
# Verify initialization
assert entity._pause_model is mock_pause_model
assert entity._cached_state is None
def test_from_models_classmethod(self):
"""Test from_models class method."""
# Create mock models
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
mock_pause_model.workflow_run_id = "execution-456"
# Create entity using from_models
entity = _PrivateWorkflowPauseEntity.from_models(
workflow_pause_model=mock_pause_model,
)
# Verify entity creation
assert isinstance(entity, _PrivateWorkflowPauseEntity)
assert entity._pause_model is mock_pause_model
def test_id_property(self):
"""Test id property returns pause model ID."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
assert entity.id == "pause-123"
@ -59,9 +39,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.workflow_run_id = "execution-456"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
assert entity.workflow_execution_id == "execution-456"
@ -72,9 +50,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = resumed_at
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
assert entity.resumed_at == resumed_at
@ -83,9 +59,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = None
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
assert entity.resumed_at is None
@ -98,9 +72,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
# First call should load from storage
result = entity.get_state()
@ -118,9 +90,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
# First call
result1 = entity.get_state()
@ -139,9 +109,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
# Pre-cache data
entity._cached_state = state_data
@ -162,9 +130,7 @@ class TestPrivateWorkflowPauseEntity:
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
result = entity.get_state()

View File

@ -8,12 +8,13 @@ from typing import Any
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

View File

@ -178,8 +178,7 @@ def test_pause_command():
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
assert len(pause_events) == 1
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")]
graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.paused
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]

View File

@ -0,0 +1,488 @@
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.exc import (
CodeNodeError,
DepthLimitError,
OutputValidationError,
)
class TestCodeNodeExceptions:
"""Test suite for code node exceptions."""
def test_code_node_error_is_value_error(self):
"""Test CodeNodeError inherits from ValueError."""
error = CodeNodeError("test error")
assert isinstance(error, ValueError)
assert str(error) == "test error"
def test_output_validation_error_is_code_node_error(self):
"""Test OutputValidationError inherits from CodeNodeError."""
error = OutputValidationError("validation failed")
assert isinstance(error, CodeNodeError)
assert isinstance(error, ValueError)
assert str(error) == "validation failed"
def test_depth_limit_error_is_code_node_error(self):
"""Test DepthLimitError inherits from CodeNodeError."""
error = DepthLimitError("depth exceeded")
assert isinstance(error, CodeNodeError)
assert isinstance(error, ValueError)
assert str(error) == "depth exceeded"
def test_code_node_error_with_empty_message(self):
"""Test CodeNodeError with empty message."""
error = CodeNodeError("")
assert str(error) == ""
def test_output_validation_error_with_field_info(self):
"""Test OutputValidationError with field information."""
error = OutputValidationError("Output 'result' is not a valid type")
assert "result" in str(error)
assert "not a valid type" in str(error)
def test_depth_limit_error_with_limit_info(self):
"""Test DepthLimitError with limit information."""
error = DepthLimitError("Depth limit 5 reached, object too deep")
assert "5" in str(error)
assert "too deep" in str(error)
class TestCodeNodeClassMethods:
"""Test suite for CodeNode class methods."""
def test_code_node_version(self):
"""Test CodeNode version method."""
version = CodeNode.version()
assert version == "1"
def test_get_default_config_python3(self):
"""Test get_default_config for Python3."""
config = CodeNode.get_default_config(filters={"code_language": CodeLanguage.PYTHON3})
assert config is not None
assert isinstance(config, dict)
def test_get_default_config_javascript(self):
"""Test get_default_config for JavaScript."""
config = CodeNode.get_default_config(filters={"code_language": CodeLanguage.JAVASCRIPT})
assert config is not None
assert isinstance(config, dict)
def test_get_default_config_no_filters(self):
"""Test get_default_config with no filters defaults to Python3."""
config = CodeNode.get_default_config()
assert config is not None
assert isinstance(config, dict)
def test_get_default_config_empty_filters(self):
"""Test get_default_config with empty filters."""
config = CodeNode.get_default_config(filters={})
assert config is not None
class TestCodeNodeCheckMethods:
"""Test suite for CodeNode check methods."""
def test_check_string_none_value(self):
"""Test _check_string with None value."""
node = CodeNode.__new__(CodeNode)
result = node._check_string(None, "test_var")
assert result is None
def test_check_string_removes_null_bytes(self):
"""Test _check_string removes null bytes."""
node = CodeNode.__new__(CodeNode)
result = node._check_string("hello\x00world", "test_var")
assert result == "helloworld"
assert "\x00" not in result
def test_check_string_valid_string(self):
"""Test _check_string with valid string."""
node = CodeNode.__new__(CodeNode)
result = node._check_string("valid string", "test_var")
assert result == "valid string"
def test_check_string_empty_string(self):
"""Test _check_string with empty string."""
node = CodeNode.__new__(CodeNode)
result = node._check_string("", "test_var")
assert result == ""
def test_check_string_with_unicode(self):
"""Test _check_string with unicode characters."""
node = CodeNode.__new__(CodeNode)
result = node._check_string("你好世界🌍", "test_var")
assert result == "你好世界🌍"
def test_check_boolean_none_value(self):
"""Test _check_boolean with None value."""
node = CodeNode.__new__(CodeNode)
result = node._check_boolean(None, "test_var")
assert result is None
def test_check_boolean_true_value(self):
"""Test _check_boolean with True value."""
node = CodeNode.__new__(CodeNode)
result = node._check_boolean(True, "test_var")
assert result is True
def test_check_boolean_false_value(self):
"""Test _check_boolean with False value."""
node = CodeNode.__new__(CodeNode)
result = node._check_boolean(False, "test_var")
assert result is False
def test_check_number_none_value(self):
"""Test _check_number with None value."""
node = CodeNode.__new__(CodeNode)
result = node._check_number(None, "test_var")
assert result is None
def test_check_number_integer_value(self):
"""Test _check_number with integer value."""
node = CodeNode.__new__(CodeNode)
result = node._check_number(42, "test_var")
assert result == 42
def test_check_number_float_value(self):
"""Test _check_number with float value."""
node = CodeNode.__new__(CodeNode)
result = node._check_number(3.14, "test_var")
assert result == 3.14
def test_check_number_zero(self):
"""Test _check_number with zero."""
node = CodeNode.__new__(CodeNode)
result = node._check_number(0, "test_var")
assert result == 0
def test_check_number_negative(self):
"""Test _check_number with negative number."""
node = CodeNode.__new__(CodeNode)
result = node._check_number(-100, "test_var")
assert result == -100
def test_check_number_negative_float(self):
"""Test _check_number with negative float."""
node = CodeNode.__new__(CodeNode)
result = node._check_number(-3.14159, "test_var")
assert result == -3.14159
class TestCodeNodeConvertBooleanToInt:
"""Test suite for _convert_boolean_to_int static method."""
def test_convert_none_returns_none(self):
"""Test converting None returns None."""
result = CodeNode._convert_boolean_to_int(None)
assert result is None
def test_convert_true_returns_one(self):
"""Test converting True returns 1."""
result = CodeNode._convert_boolean_to_int(True)
assert result == 1
assert isinstance(result, int)
def test_convert_false_returns_zero(self):
"""Test converting False returns 0."""
result = CodeNode._convert_boolean_to_int(False)
assert result == 0
assert isinstance(result, int)
def test_convert_integer_returns_same(self):
"""Test converting integer returns same value."""
result = CodeNode._convert_boolean_to_int(42)
assert result == 42
def test_convert_float_returns_same(self):
"""Test converting float returns same value."""
result = CodeNode._convert_boolean_to_int(3.14)
assert result == 3.14
def test_convert_zero_returns_zero(self):
"""Test converting zero returns zero."""
result = CodeNode._convert_boolean_to_int(0)
assert result == 0
def test_convert_negative_returns_same(self):
"""Test converting negative number returns same value."""
result = CodeNode._convert_boolean_to_int(-100)
assert result == -100
class TestCodeNodeExtractVariableSelector:
"""Test suite for _extract_variable_selector_to_variable_mapping."""
def test_extract_empty_variables(self):
"""Test extraction with no variables."""
node_data = {
"title": "Test",
"variables": [],
"code_language": "python3",
"code": "def main(): return {}",
"outputs": {},
}
result = CodeNode._extract_variable_selector_to_variable_mapping(
graph_config={},
node_id="node_1",
node_data=node_data,
)
assert result == {}
def test_extract_single_variable(self):
"""Test extraction with single variable."""
node_data = {
"title": "Test",
"variables": [
{"variable": "input_text", "value_selector": ["start", "text"]},
],
"code_language": "python3",
"code": "def main(): return {}",
"outputs": {},
}
result = CodeNode._extract_variable_selector_to_variable_mapping(
graph_config={},
node_id="node_1",
node_data=node_data,
)
assert "node_1.input_text" in result
assert result["node_1.input_text"] == ["start", "text"]
def test_extract_multiple_variables(self):
"""Test extraction with multiple variables."""
node_data = {
"title": "Test",
"variables": [
{"variable": "var1", "value_selector": ["node_a", "output1"]},
{"variable": "var2", "value_selector": ["node_b", "output2"]},
{"variable": "var3", "value_selector": ["node_c", "output3"]},
],
"code_language": "python3",
"code": "def main(): return {}",
"outputs": {},
}
result = CodeNode._extract_variable_selector_to_variable_mapping(
graph_config={},
node_id="code_node",
node_data=node_data,
)
assert len(result) == 3
assert "code_node.var1" in result
assert "code_node.var2" in result
assert "code_node.var3" in result
def test_extract_with_nested_selector(self):
"""Test extraction with nested value selector."""
node_data = {
"title": "Test",
"variables": [
{"variable": "deep_var", "value_selector": ["node", "obj", "nested", "value"]},
],
"code_language": "python3",
"code": "def main(): return {}",
"outputs": {},
}
result = CodeNode._extract_variable_selector_to_variable_mapping(
graph_config={},
node_id="node_x",
node_data=node_data,
)
assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"]
class TestCodeNodeDataValidation:
"""Test suite for CodeNodeData validation scenarios."""
def test_valid_python3_code_node_data(self):
"""Test valid Python3 CodeNodeData."""
data = CodeNodeData(
title="Python Code",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {'result': 1}",
outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
)
assert data.code_language == CodeLanguage.PYTHON3
def test_valid_javascript_code_node_data(self):
"""Test valid JavaScript CodeNodeData."""
data = CodeNodeData(
title="JS Code",
variables=[],
code_language=CodeLanguage.JAVASCRIPT,
code="function main() { return { result: 1 }; }",
outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
)
assert data.code_language == CodeLanguage.JAVASCRIPT
def test_code_node_data_with_all_output_types(self):
"""Test CodeNodeData with all valid output types."""
data = CodeNodeData(
title="All Types",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {}",
outputs={
"str_out": CodeNodeData.Output(type=SegmentType.STRING),
"num_out": CodeNodeData.Output(type=SegmentType.NUMBER),
"bool_out": CodeNodeData.Output(type=SegmentType.BOOLEAN),
"obj_out": CodeNodeData.Output(type=SegmentType.OBJECT),
"arr_str": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
"arr_num": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER),
"arr_bool": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN),
"arr_obj": CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT),
},
)
assert len(data.outputs) == 8
def test_code_node_data_complex_nested_output(self):
"""Test CodeNodeData with complex nested output structure."""
data = CodeNodeData(
title="Complex Output",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {}",
outputs={
"response": CodeNodeData.Output(
type=SegmentType.OBJECT,
children={
"data": CodeNodeData.Output(
type=SegmentType.OBJECT,
children={
"items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
"count": CodeNodeData.Output(type=SegmentType.NUMBER),
},
),
"status": CodeNodeData.Output(type=SegmentType.STRING),
"success": CodeNodeData.Output(type=SegmentType.BOOLEAN),
},
),
},
)
assert data.outputs["response"].type == SegmentType.OBJECT
assert data.outputs["response"].children is not None
assert "data" in data.outputs["response"].children
assert data.outputs["response"].children["data"].children is not None
class TestCodeNodeInitialization:
"""Test suite for CodeNode initialization methods."""
def test_init_node_data_python3(self):
"""Test init_node_data with Python3 configuration."""
node = CodeNode.__new__(CodeNode)
data = {
"title": "Test Node",
"variables": [],
"code_language": "python3",
"code": "def main(): return {'x': 1}",
"outputs": {"x": {"type": "number"}},
}
node.init_node_data(data)
assert node._node_data.title == "Test Node"
assert node._node_data.code_language == CodeLanguage.PYTHON3
def test_init_node_data_javascript(self):
"""Test init_node_data with JavaScript configuration."""
node = CodeNode.__new__(CodeNode)
data = {
"title": "JS Node",
"variables": [],
"code_language": "javascript",
"code": "function main() { return { x: 1 }; }",
"outputs": {"x": {"type": "number"}},
}
node.init_node_data(data)
assert node._node_data.code_language == CodeLanguage.JAVASCRIPT
def test_get_title(self):
"""Test _get_title method."""
node = CodeNode.__new__(CodeNode)
node._node_data = CodeNodeData(
title="My Code Node",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="",
outputs={},
)
assert node._get_title() == "My Code Node"
def test_get_description_none(self):
"""Test _get_description returns None when not set."""
node = CodeNode.__new__(CodeNode)
node._node_data = CodeNodeData(
title="Test",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="",
outputs={},
)
assert node._get_description() is None
def test_get_base_node_data(self):
"""Test get_base_node_data returns node data."""
node = CodeNode.__new__(CodeNode)
node._node_data = CodeNodeData(
title="Base Test",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="",
outputs={},
)
result = node.get_base_node_data()
assert result == node._node_data
assert result.title == "Base Test"

View File

@ -0,0 +1,353 @@
import pytest
from pydantic import ValidationError
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.code.entities import CodeNodeData
class TestCodeNodeDataOutput:
"""Test suite for CodeNodeData.Output model."""
def test_output_with_string_type(self):
"""Test Output with STRING type."""
output = CodeNodeData.Output(type=SegmentType.STRING)
assert output.type == SegmentType.STRING
assert output.children is None
def test_output_with_number_type(self):
"""Test Output with NUMBER type."""
output = CodeNodeData.Output(type=SegmentType.NUMBER)
assert output.type == SegmentType.NUMBER
assert output.children is None
def test_output_with_boolean_type(self):
"""Test Output with BOOLEAN type."""
output = CodeNodeData.Output(type=SegmentType.BOOLEAN)
assert output.type == SegmentType.BOOLEAN
def test_output_with_object_type(self):
"""Test Output with OBJECT type."""
output = CodeNodeData.Output(type=SegmentType.OBJECT)
assert output.type == SegmentType.OBJECT
def test_output_with_array_string_type(self):
"""Test Output with ARRAY_STRING type."""
output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING)
assert output.type == SegmentType.ARRAY_STRING
def test_output_with_array_number_type(self):
"""Test Output with ARRAY_NUMBER type."""
output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)
assert output.type == SegmentType.ARRAY_NUMBER
def test_output_with_array_object_type(self):
"""Test Output with ARRAY_OBJECT type."""
output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT)
assert output.type == SegmentType.ARRAY_OBJECT
def test_output_with_array_boolean_type(self):
"""Test Output with ARRAY_BOOLEAN type."""
output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)
assert output.type == SegmentType.ARRAY_BOOLEAN
def test_output_with_nested_children(self):
"""Test Output with nested children for OBJECT type."""
child_output = CodeNodeData.Output(type=SegmentType.STRING)
parent_output = CodeNodeData.Output(
type=SegmentType.OBJECT,
children={"name": child_output},
)
assert parent_output.type == SegmentType.OBJECT
assert parent_output.children is not None
assert "name" in parent_output.children
assert parent_output.children["name"].type == SegmentType.STRING
def test_output_with_deeply_nested_children(self):
"""Test Output with deeply nested children."""
inner_child = CodeNodeData.Output(type=SegmentType.NUMBER)
middle_child = CodeNodeData.Output(
type=SegmentType.OBJECT,
children={"value": inner_child},
)
outer_output = CodeNodeData.Output(
type=SegmentType.OBJECT,
children={"nested": middle_child},
)
assert outer_output.children is not None
assert outer_output.children["nested"].children is not None
assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER
def test_output_with_multiple_children(self):
"""Test Output with multiple children."""
output = CodeNodeData.Output(
type=SegmentType.OBJECT,
children={
"name": CodeNodeData.Output(type=SegmentType.STRING),
"age": CodeNodeData.Output(type=SegmentType.NUMBER),
"active": CodeNodeData.Output(type=SegmentType.BOOLEAN),
},
)
assert output.children is not None
assert len(output.children) == 3
assert output.children["name"].type == SegmentType.STRING
assert output.children["age"].type == SegmentType.NUMBER
assert output.children["active"].type == SegmentType.BOOLEAN
def test_output_rejects_invalid_type(self):
"""Test Output rejects invalid segment types."""
with pytest.raises(ValidationError):
CodeNodeData.Output(type=SegmentType.FILE)
def test_output_rejects_array_file_type(self):
"""Test Output rejects ARRAY_FILE type."""
with pytest.raises(ValidationError):
CodeNodeData.Output(type=SegmentType.ARRAY_FILE)
class TestCodeNodeDataDependency:
"""Test suite for CodeNodeData.Dependency model."""
def test_dependency_basic(self):
"""Test Dependency with name and version."""
dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0")
assert dependency.name == "numpy"
assert dependency.version == "1.24.0"
def test_dependency_with_complex_version(self):
"""Test Dependency with complex version string."""
dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0")
assert dependency.name == "pandas"
assert dependency.version == ">=2.0.0,<3.0.0"
def test_dependency_with_empty_version(self):
"""Test Dependency with empty version."""
dependency = CodeNodeData.Dependency(name="requests", version="")
assert dependency.name == "requests"
assert dependency.version == ""
class TestCodeNodeData:
"""Test suite for CodeNodeData model."""
def test_code_node_data_python3(self):
"""Test CodeNodeData with Python3 language."""
data = CodeNodeData(
title="Test Code Node",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {'result': 42}",
outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
)
assert data.title == "Test Code Node"
assert data.code_language == CodeLanguage.PYTHON3
assert data.code == "def main(): return {'result': 42}"
assert "result" in data.outputs
assert data.dependencies is None
def test_code_node_data_javascript(self):
"""Test CodeNodeData with JavaScript language."""
data = CodeNodeData(
title="JS Code Node",
variables=[],
code_language=CodeLanguage.JAVASCRIPT,
code="function main() { return { result: 'hello' }; }",
outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)},
)
assert data.code_language == CodeLanguage.JAVASCRIPT
assert "result" in data.outputs
assert data.outputs["result"].type == SegmentType.STRING
def test_code_node_data_with_dependencies(self):
"""Test CodeNodeData with dependencies."""
data = CodeNodeData(
title="Code with Deps",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="import numpy as np\ndef main(): return {'sum': 10}",
outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)},
dependencies=[
CodeNodeData.Dependency(name="numpy", version="1.24.0"),
CodeNodeData.Dependency(name="pandas", version="2.0.0"),
],
)
assert data.dependencies is not None
assert len(data.dependencies) == 2
assert data.dependencies[0].name == "numpy"
assert data.dependencies[1].name == "pandas"
def test_code_node_data_with_multiple_outputs(self):
"""Test CodeNodeData with multiple outputs."""
data = CodeNodeData(
title="Multi Output",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}",
outputs={
"name": CodeNodeData.Output(type=SegmentType.STRING),
"count": CodeNodeData.Output(type=SegmentType.NUMBER),
"items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
},
)
assert len(data.outputs) == 3
assert data.outputs["name"].type == SegmentType.STRING
assert data.outputs["count"].type == SegmentType.NUMBER
assert data.outputs["items"].type == SegmentType.ARRAY_STRING
def test_code_node_data_with_object_output(self):
"""Test CodeNodeData with nested object output."""
data = CodeNodeData(
title="Object Output",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {'user': {'name': 'John', 'age': 30}}",
outputs={
"user": CodeNodeData.Output(
type=SegmentType.OBJECT,
children={
"name": CodeNodeData.Output(type=SegmentType.STRING),
"age": CodeNodeData.Output(type=SegmentType.NUMBER),
},
),
},
)
assert data.outputs["user"].type == SegmentType.OBJECT
assert data.outputs["user"].children is not None
assert len(data.outputs["user"].children) == 2
def test_code_node_data_with_array_object_output(self):
"""Test CodeNodeData with array of objects output."""
data = CodeNodeData(
title="Array Object Output",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}",
outputs={
"users": CodeNodeData.Output(
type=SegmentType.ARRAY_OBJECT,
children={
"name": CodeNodeData.Output(type=SegmentType.STRING),
},
),
},
)
assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT
assert data.outputs["users"].children is not None
def test_code_node_data_empty_code(self):
"""Test CodeNodeData with empty code."""
data = CodeNodeData(
title="Empty Code",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="",
outputs={},
)
assert data.code == ""
assert len(data.outputs) == 0
def test_code_node_data_multiline_code(self):
"""Test CodeNodeData with multiline code."""
multiline_code = """
def main():
result = 0
for i in range(10):
result += i
return {'sum': result}
"""
data = CodeNodeData(
title="Multiline Code",
variables=[],
code_language=CodeLanguage.PYTHON3,
code=multiline_code,
outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)},
)
assert "for i in range(10)" in data.code
assert "result += i" in data.code
def test_code_node_data_with_special_characters_in_code(self):
"""Test CodeNodeData with special characters in code."""
code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}"
data = CodeNodeData(
title="Special Chars",
variables=[],
code_language=CodeLanguage.PYTHON3,
code=code_with_special,
outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)},
)
assert "\\n" in data.code
assert "\\t" in data.code
def test_code_node_data_with_unicode_in_code(self):
"""Test CodeNodeData with unicode characters in code."""
unicode_code = "def main(): return {'greeting': '你好世界'}"
data = CodeNodeData(
title="Unicode Code",
variables=[],
code_language=CodeLanguage.PYTHON3,
code=unicode_code,
outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)},
)
assert "你好世界" in data.code
def test_code_node_data_empty_dependencies_list(self):
"""Test CodeNodeData with empty dependencies list."""
data = CodeNodeData(
title="No Deps",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {}",
outputs={},
dependencies=[],
)
assert data.dependencies is not None
assert len(data.dependencies) == 0
def test_code_node_data_with_boolean_array_output(self):
"""Test CodeNodeData with boolean array output."""
data = CodeNodeData(
title="Boolean Array",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {'flags': [True, False, True]}",
outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)},
)
assert data.outputs["flags"].type == SegmentType.ARRAY_BOOLEAN
def test_code_node_data_with_number_array_output(self):
"""Test CodeNodeData with number array output."""
data = CodeNodeData(
title="Number Array",
variables=[],
code_language=CodeLanguage.PYTHON3,
code="def main(): return {'values': [1, 2, 3, 4, 5]}",
outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)},
)
assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER

View File

@ -0,0 +1,339 @@
from core.workflow.nodes.iteration.entities import (
ErrorHandleMode,
IterationNodeData,
IterationStartNodeData,
IterationState,
)
class TestErrorHandleMode:
"""Test suite for ErrorHandleMode enum."""
def test_terminated_value(self):
"""Test TERMINATED enum value."""
assert ErrorHandleMode.TERMINATED == "terminated"
assert ErrorHandleMode.TERMINATED.value == "terminated"
def test_continue_on_error_value(self):
"""Test CONTINUE_ON_ERROR enum value."""
assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error"
def test_remove_abnormal_output_value(self):
"""Test REMOVE_ABNORMAL_OUTPUT enum value."""
assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output"
assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output"
def test_error_handle_mode_is_str_enum(self):
"""Test ErrorHandleMode is a string enum."""
assert isinstance(ErrorHandleMode.TERMINATED, str)
assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str)
assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str)
def test_error_handle_mode_comparison(self):
"""Test ErrorHandleMode can be compared with strings."""
assert ErrorHandleMode.TERMINATED == "terminated"
assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
def test_all_error_handle_modes(self):
"""Test all ErrorHandleMode values are accessible."""
modes = list(ErrorHandleMode)
assert len(modes) == 3
assert ErrorHandleMode.TERMINATED in modes
assert ErrorHandleMode.CONTINUE_ON_ERROR in modes
assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes
class TestIterationNodeData:
"""Test suite for IterationNodeData model."""
def test_iteration_node_data_basic(self):
"""Test IterationNodeData with basic configuration."""
data = IterationNodeData(
title="Test Iteration",
iterator_selector=["node1", "output"],
output_selector=["iteration", "result"],
)
assert data.title == "Test Iteration"
assert data.iterator_selector == ["node1", "output"]
assert data.output_selector == ["iteration", "result"]
def test_iteration_node_data_default_values(self):
"""Test IterationNodeData default values."""
data = IterationNodeData(
title="Default Test",
iterator_selector=["start", "items"],
output_selector=["iter", "out"],
)
assert data.parent_loop_id is None
assert data.is_parallel is False
assert data.parallel_nums == 10
assert data.error_handle_mode == ErrorHandleMode.TERMINATED
assert data.flatten_output is True
def test_iteration_node_data_parallel_mode(self):
"""Test IterationNodeData with parallel mode enabled."""
data = IterationNodeData(
title="Parallel Iteration",
iterator_selector=["node", "list"],
output_selector=["iter", "output"],
is_parallel=True,
parallel_nums=5,
)
assert data.is_parallel is True
assert data.parallel_nums == 5
def test_iteration_node_data_custom_parallel_nums(self):
"""Test IterationNodeData with custom parallel numbers."""
data = IterationNodeData(
title="Custom Parallel",
iterator_selector=["a", "b"],
output_selector=["c", "d"],
parallel_nums=20,
)
assert data.parallel_nums == 20
def test_iteration_node_data_continue_on_error(self):
"""Test IterationNodeData with continue on error mode."""
data = IterationNodeData(
title="Continue Error",
iterator_selector=["x", "y"],
output_selector=["z", "w"],
error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
)
assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
def test_iteration_node_data_remove_abnormal_output(self):
"""Test IterationNodeData with remove abnormal output mode."""
data = IterationNodeData(
title="Remove Abnormal",
iterator_selector=["input", "array"],
output_selector=["output", "result"],
error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
)
assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
def test_iteration_node_data_flatten_output_disabled(self):
"""Test IterationNodeData with flatten output disabled."""
data = IterationNodeData(
title="No Flatten",
iterator_selector=["a"],
output_selector=["b"],
flatten_output=False,
)
assert data.flatten_output is False
def test_iteration_node_data_with_parent_loop_id(self):
"""Test IterationNodeData with parent loop ID."""
data = IterationNodeData(
title="Nested Loop",
iterator_selector=["parent", "items"],
output_selector=["child", "output"],
parent_loop_id="parent_loop_123",
)
assert data.parent_loop_id == "parent_loop_123"
def test_iteration_node_data_complex_selectors(self):
"""Test IterationNodeData with complex selectors."""
data = IterationNodeData(
title="Complex Selectors",
iterator_selector=["node1", "output", "data", "items"],
output_selector=["iteration", "result", "value"],
)
assert len(data.iterator_selector) == 4
assert len(data.output_selector) == 3
def test_iteration_node_data_all_options(self):
"""Test IterationNodeData with all options configured."""
data = IterationNodeData(
title="Full Config",
iterator_selector=["start", "list"],
output_selector=["end", "result"],
parent_loop_id="outer_loop",
is_parallel=True,
parallel_nums=15,
error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
flatten_output=False,
)
assert data.title == "Full Config"
assert data.parent_loop_id == "outer_loop"
assert data.is_parallel is True
assert data.parallel_nums == 15
assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
assert data.flatten_output is False
class TestIterationStartNodeData:
"""Test suite for IterationStartNodeData model."""
def test_iteration_start_node_data_basic(self):
"""Test IterationStartNodeData basic creation."""
data = IterationStartNodeData(title="Iteration Start")
assert data.title == "Iteration Start"
def test_iteration_start_node_data_with_description(self):
"""Test IterationStartNodeData with description."""
data = IterationStartNodeData(
title="Start Node",
desc="This is the start of iteration",
)
assert data.title == "Start Node"
assert data.desc == "This is the start of iteration"
class TestIterationState:
"""Test suite for IterationState model."""
def test_iteration_state_default_values(self):
"""Test IterationState default values."""
state = IterationState()
assert state.outputs == []
assert state.current_output is None
def test_iteration_state_with_outputs(self):
"""Test IterationState with outputs."""
state = IterationState(outputs=["result1", "result2", "result3"])
assert len(state.outputs) == 3
assert state.outputs[0] == "result1"
assert state.outputs[2] == "result3"
def test_iteration_state_with_current_output(self):
"""Test IterationState with current output."""
state = IterationState(current_output="current_value")
assert state.current_output == "current_value"
def test_iteration_state_get_last_output_with_outputs(self):
"""Test get_last_output with outputs present."""
state = IterationState(outputs=["first", "second", "last"])
result = state.get_last_output()
assert result == "last"
def test_iteration_state_get_last_output_empty(self):
"""Test get_last_output with empty outputs."""
state = IterationState(outputs=[])
result = state.get_last_output()
assert result is None
def test_iteration_state_get_last_output_single(self):
"""Test get_last_output with single output."""
state = IterationState(outputs=["only_one"])
result = state.get_last_output()
assert result == "only_one"
def test_iteration_state_get_current_output(self):
"""Test get_current_output method."""
state = IterationState(current_output={"key": "value"})
result = state.get_current_output()
assert result == {"key": "value"}
def test_iteration_state_get_current_output_none(self):
"""Test get_current_output when None."""
state = IterationState()
result = state.get_current_output()
assert result is None
def test_iteration_state_with_complex_outputs(self):
"""Test IterationState with complex output types."""
state = IterationState(
outputs=[
{"id": 1, "name": "first"},
{"id": 2, "name": "second"},
[1, 2, 3],
"string_output",
]
)
assert len(state.outputs) == 4
assert state.outputs[0] == {"id": 1, "name": "first"}
assert state.outputs[2] == [1, 2, 3]
def test_iteration_state_with_none_outputs(self):
"""Test IterationState with None values in outputs."""
state = IterationState(outputs=["value1", None, "value3"])
assert len(state.outputs) == 3
assert state.outputs[1] is None
def test_iteration_state_get_last_output_with_none(self):
"""Test get_last_output when last output is None."""
state = IterationState(outputs=["first", None])
result = state.get_last_output()
assert result is None
def test_iteration_state_metadata_class(self):
"""Test IterationState.MetaData class."""
metadata = IterationState.MetaData(iterator_length=10)
assert metadata.iterator_length == 10
def test_iteration_state_metadata_different_lengths(self):
"""Test IterationState.MetaData with different lengths."""
metadata1 = IterationState.MetaData(iterator_length=0)
metadata2 = IterationState.MetaData(iterator_length=100)
metadata3 = IterationState.MetaData(iterator_length=1000000)
assert metadata1.iterator_length == 0
assert metadata2.iterator_length == 100
assert metadata3.iterator_length == 1000000
def test_iteration_state_outputs_modification(self):
"""Test modifying IterationState outputs."""
state = IterationState(outputs=[])
state.outputs.append("new_output")
state.outputs.append("another_output")
assert len(state.outputs) == 2
assert state.get_last_output() == "another_output"
def test_iteration_state_current_output_update(self):
"""Test updating current_output."""
state = IterationState()
state.current_output = "first_value"
assert state.get_current_output() == "first_value"
state.current_output = "updated_value"
assert state.get_current_output() == "updated_value"
def test_iteration_state_with_numeric_outputs(self):
"""Test IterationState with numeric outputs."""
state = IterationState(outputs=[1, 2, 3, 4, 5])
assert state.get_last_output() == 5
assert len(state.outputs) == 5
def test_iteration_state_with_boolean_outputs(self):
"""Test IterationState with boolean outputs."""
state = IterationState(outputs=[True, False, True])
assert state.get_last_output() is True
assert state.outputs[1] is False

View File

@ -0,0 +1,390 @@
from core.workflow.enums import NodeType
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.nodes.iteration.exc import (
InvalidIteratorValueError,
IterationGraphNotFoundError,
IterationIndexNotFoundError,
IterationNodeError,
IteratorVariableNotFoundError,
StartNodeIdNotFoundError,
)
from core.workflow.nodes.iteration.iteration_node import IterationNode
class TestIterationNodeExceptions:
"""Test suite for iteration node exceptions."""
def test_iteration_node_error_is_value_error(self):
"""Test IterationNodeError inherits from ValueError."""
error = IterationNodeError("test error")
assert isinstance(error, ValueError)
assert str(error) == "test error"
def test_iterator_variable_not_found_error(self):
"""Test IteratorVariableNotFoundError."""
error = IteratorVariableNotFoundError("Iterator variable not found")
assert isinstance(error, IterationNodeError)
assert isinstance(error, ValueError)
assert "Iterator variable not found" in str(error)
def test_invalid_iterator_value_error(self):
"""Test InvalidIteratorValueError."""
error = InvalidIteratorValueError("Invalid iterator value")
assert isinstance(error, IterationNodeError)
assert "Invalid iterator value" in str(error)
def test_start_node_id_not_found_error(self):
"""Test StartNodeIdNotFoundError."""
error = StartNodeIdNotFoundError("Start node ID not found")
assert isinstance(error, IterationNodeError)
assert "Start node ID not found" in str(error)
def test_iteration_graph_not_found_error(self):
"""Test IterationGraphNotFoundError."""
error = IterationGraphNotFoundError("Iteration graph not found")
assert isinstance(error, IterationNodeError)
assert "Iteration graph not found" in str(error)
def test_iteration_index_not_found_error(self):
"""Test IterationIndexNotFoundError."""
error = IterationIndexNotFoundError("Iteration index not found")
assert isinstance(error, IterationNodeError)
assert "Iteration index not found" in str(error)
def test_exception_with_empty_message(self):
"""Test exception with empty message."""
error = IterationNodeError("")
assert str(error) == ""
def test_exception_with_detailed_message(self):
"""Test exception with detailed message."""
error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'")
assert "items" in str(error)
assert "start_node" in str(error)
def test_all_exceptions_inherit_from_base(self):
"""Test all exceptions inherit from IterationNodeError."""
exceptions = [
IteratorVariableNotFoundError("test"),
InvalidIteratorValueError("test"),
StartNodeIdNotFoundError("test"),
IterationGraphNotFoundError("test"),
IterationIndexNotFoundError("test"),
]
for exc in exceptions:
assert isinstance(exc, IterationNodeError)
assert isinstance(exc, ValueError)
class TestIterationNodeClassAttributes:
"""Test suite for IterationNode class attributes."""
def test_node_type(self):
"""Test IterationNode node_type attribute."""
assert IterationNode.node_type == NodeType.ITERATION
def test_version(self):
"""Test IterationNode version method."""
version = IterationNode.version()
assert version == "1"
class TestIterationNodeDefaultConfig:
"""Test suite for IterationNode get_default_config."""
def test_get_default_config_returns_dict(self):
"""Test get_default_config returns a dictionary."""
config = IterationNode.get_default_config()
assert isinstance(config, dict)
def test_get_default_config_type(self):
"""Test get_default_config includes type."""
config = IterationNode.get_default_config()
assert config.get("type") == "iteration"
def test_get_default_config_has_config_section(self):
"""Test get_default_config has config section."""
config = IterationNode.get_default_config()
assert "config" in config
assert isinstance(config["config"], dict)
def test_get_default_config_is_parallel_default(self):
"""Test get_default_config is_parallel default value."""
config = IterationNode.get_default_config()
assert config["config"]["is_parallel"] is False
def test_get_default_config_parallel_nums_default(self):
"""Test get_default_config parallel_nums default value."""
config = IterationNode.get_default_config()
assert config["config"]["parallel_nums"] == 10
def test_get_default_config_error_handle_mode_default(self):
"""Test get_default_config error_handle_mode default value."""
config = IterationNode.get_default_config()
assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED
def test_get_default_config_flatten_output_default(self):
"""Test get_default_config flatten_output default value."""
config = IterationNode.get_default_config()
assert config["config"]["flatten_output"] is True
def test_get_default_config_with_none_filters(self):
"""Test get_default_config with None filters."""
config = IterationNode.get_default_config(filters=None)
assert config is not None
assert "type" in config
def test_get_default_config_with_empty_filters(self):
"""Test get_default_config with empty filters."""
config = IterationNode.get_default_config(filters={})
assert config is not None
class TestIterationNodeInitialization:
"""Test suite for IterationNode initialization."""
def test_init_node_data_basic(self):
"""Test init_node_data with basic configuration."""
node = IterationNode.__new__(IterationNode)
data = {
"title": "Test Iteration",
"iterator_selector": ["start", "items"],
"output_selector": ["iteration", "result"],
}
node.init_node_data(data)
assert node._node_data.title == "Test Iteration"
assert node._node_data.iterator_selector == ["start", "items"]
def test_init_node_data_with_parallel(self):
"""Test init_node_data with parallel configuration."""
node = IterationNode.__new__(IterationNode)
data = {
"title": "Parallel Iteration",
"iterator_selector": ["node", "list"],
"output_selector": ["out", "result"],
"is_parallel": True,
"parallel_nums": 5,
}
node.init_node_data(data)
assert node._node_data.is_parallel is True
assert node._node_data.parallel_nums == 5
def test_init_node_data_with_error_handle_mode(self):
"""Test init_node_data with error handle mode."""
node = IterationNode.__new__(IterationNode)
data = {
"title": "Error Handle Test",
"iterator_selector": ["a", "b"],
"output_selector": ["c", "d"],
"error_handle_mode": "continue-on-error",
}
node.init_node_data(data)
assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
def test_get_title(self):
"""Test _get_title method."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="My Iteration",
iterator_selector=["x"],
output_selector=["y"],
)
assert node._get_title() == "My Iteration"
def test_get_description_none(self):
"""Test _get_description returns None when not set."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
iterator_selector=["a"],
output_selector=["b"],
)
assert node._get_description() is None
def test_get_description_with_value(self):
"""Test _get_description with value."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
desc="This is a description",
iterator_selector=["a"],
output_selector=["b"],
)
assert node._get_description() == "This is a description"
def test_get_base_node_data(self):
"""Test get_base_node_data returns node data."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Base Test",
iterator_selector=["x"],
output_selector=["y"],
)
result = node.get_base_node_data()
assert result == node._node_data
class TestIterationNodeDataValidation:
"""Test suite for IterationNodeData validation scenarios."""
def test_valid_iteration_node_data(self):
"""Test valid IterationNodeData creation."""
data = IterationNodeData(
title="Valid Iteration",
iterator_selector=["start", "items"],
output_selector=["end", "result"],
)
assert data.title == "Valid Iteration"
def test_iteration_node_data_with_all_error_modes(self):
"""Test IterationNodeData with all error handle modes."""
modes = [
ErrorHandleMode.TERMINATED,
ErrorHandleMode.CONTINUE_ON_ERROR,
ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
]
for mode in modes:
data = IterationNodeData(
title=f"Test {mode}",
iterator_selector=["a"],
output_selector=["b"],
error_handle_mode=mode,
)
assert data.error_handle_mode == mode
def test_iteration_node_data_parallel_configuration(self):
"""Test IterationNodeData parallel configuration combinations."""
configs = [
(False, 10),
(True, 1),
(True, 5),
(True, 20),
(True, 100),
]
for is_parallel, parallel_nums in configs:
data = IterationNodeData(
title="Parallel Test",
iterator_selector=["x"],
output_selector=["y"],
is_parallel=is_parallel,
parallel_nums=parallel_nums,
)
assert data.is_parallel == is_parallel
assert data.parallel_nums == parallel_nums
def test_iteration_node_data_flatten_output_options(self):
"""Test IterationNodeData flatten_output options."""
data_flatten = IterationNodeData(
title="Flatten True",
iterator_selector=["a"],
output_selector=["b"],
flatten_output=True,
)
data_no_flatten = IterationNodeData(
title="Flatten False",
iterator_selector=["a"],
output_selector=["b"],
flatten_output=False,
)
assert data_flatten.flatten_output is True
assert data_no_flatten.flatten_output is False
def test_iteration_node_data_complex_selectors(self):
"""Test IterationNodeData with complex selectors."""
data = IterationNodeData(
title="Complex",
iterator_selector=["node1", "output", "data", "items", "list"],
output_selector=["iteration", "result", "value", "final"],
)
assert len(data.iterator_selector) == 5
assert len(data.output_selector) == 4
def test_iteration_node_data_single_element_selectors(self):
"""Test IterationNodeData with single element selectors."""
data = IterationNodeData(
title="Single",
iterator_selector=["items"],
output_selector=["result"],
)
assert len(data.iterator_selector) == 1
assert len(data.output_selector) == 1
class TestIterationNodeErrorStrategies:
"""Test suite for IterationNode error strategies."""
def test_get_error_strategy_default(self):
"""Test _get_error_strategy with default value."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
iterator_selector=["a"],
output_selector=["b"],
)
result = node._get_error_strategy()
assert result is None or result == node._node_data.error_strategy
def test_get_retry_config(self):
"""Test _get_retry_config method."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
iterator_selector=["a"],
output_selector=["b"],
)
result = node._get_retry_config()
assert result is not None
def test_get_default_value_dict(self):
"""Test _get_default_value_dict method."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
iterator_selector=["a"],
output_selector=["b"],
)
result = node._get_default_value_dict()
assert isinstance(result, dict)

View File

@ -0,0 +1,544 @@
from unittest.mock import MagicMock
import pytest
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.variables import ArrayNumberSegment, ArrayStringSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.list_operator.node import ListOperatorNode
from models.workflow import WorkflowType
class TestListOperatorNode:
"""Comprehensive tests for ListOperatorNode."""
@pytest.fixture
def mock_graph_runtime_state(self):
"""Create mock GraphRuntimeState."""
mock_state = MagicMock(spec=GraphRuntimeState)
mock_variable_pool = MagicMock()
mock_state.variable_pool = mock_variable_pool
return mock_state
@pytest.fixture
def mock_graph(self):
"""Create mock Graph."""
return MagicMock(spec=Graph)
@pytest.fixture
def graph_init_params(self):
"""Create GraphInitParams fixture."""
return GraphInitParams(
tenant_id="test",
app_id="test",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="test",
graph_config={},
user_id="test",
user_from="test",
invoke_from="test",
call_depth=0,
)
@pytest.fixture
def list_operator_node_factory(self, graph_init_params, mock_graph, mock_graph_runtime_state):
"""Factory fixture for creating ListOperatorNode instances."""
def _create_node(config, mock_variable):
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
return ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
return _create_node
def test_node_initialization(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test node initializes correctly."""
config = {
"title": "List Operator",
"variable": ["sys", "list"],
"filter_by": {"enabled": False},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
assert node.node_type == NodeType.LIST_OPERATOR
assert node._node_data.title == "List Operator"
def test_version(self):
"""Test version returns correct value."""
assert ListOperatorNode.version() == "1"
def test_run_with_string_array(self, list_operator_node_factory):
"""Test with string array."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {"enabled": False},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=["apple", "banana", "cherry"])
node = list_operator_node_factory(config, mock_var)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == ["apple", "banana", "cherry"]
def test_run_with_empty_array(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test with empty array."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {"enabled": False},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=[])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == []
assert result.outputs["first_record"] is None
assert result.outputs["last_record"] is None
def test_run_with_filter_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test filter with contains condition."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {
"enabled": True,
"condition": "contains",
"value": "app",
},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == ["apple", "pineapple"]
def test_run_with_filter_not_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test filter with not contains condition."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {
"enabled": True,
"condition": "not contains",
"value": "app",
},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == ["banana", "cherry"]
def test_run_with_number_filter_greater_than(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test filter with greater than condition on numbers."""
config = {
"title": "Test",
"variable": ["sys", "numbers"],
"filter_by": {
"enabled": True,
"condition": ">",
"value": "5",
},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == [7, 9, 11]
def test_run_with_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test ordering in ascending order."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {"enabled": False},
"order_by": {
"enabled": True,
"value": "asc",
},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == ["apple", "banana", "cherry"]
def test_run_with_order_descending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test ordering in descending order."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {"enabled": False},
"order_by": {
"enabled": True,
"value": "desc",
},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == ["cherry", "banana", "apple"]
def test_run_with_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test with limit enabled."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {"enabled": False},
"order_by": {"enabled": False},
"limit": {
"enabled": True,
"size": 2,
},
}
mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == ["apple", "banana"]
def test_run_with_filter_order_and_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test with filter, order, and limit combined."""
config = {
"title": "Test",
"variable": ["sys", "numbers"],
"filter_by": {
"enabled": True,
"condition": ">",
"value": "3",
},
"order_by": {
"enabled": True,
"value": "desc",
},
"limit": {
"enabled": True,
"size": 3,
},
}
mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == [9, 8, 7]
def test_run_with_variable_not_found(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test when variable is not found."""
config = {
"title": "Test",
"variable": ["sys", "missing"],
"filter_by": {"enabled": False},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_graph_runtime_state.variable_pool.get.return_value = None
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Variable not found" in result.error
def test_run_with_first_and_last_record(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test first_record and last_record outputs."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {"enabled": False},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=["first", "middle", "last"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["first_record"] == "first"
assert result.outputs["last_record"] == "last"
def test_run_with_filter_startswith(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test filter with startswith condition."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {
"enabled": True,
"condition": "start with",
"value": "app",
},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == ["apple", "application"]
def test_run_with_filter_endswith(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test filter with endswith condition."""
config = {
"title": "Test",
"variable": ["sys", "items"],
"filter_by": {
"enabled": True,
"condition": "end with",
"value": "le",
},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == ["apple", "pineapple", "table"]
def test_run_with_number_filter_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test number filter with equals condition."""
config = {
"title": "Test",
"variable": ["sys", "numbers"],
"filter_by": {
"enabled": True,
"condition": "=",
"value": "5",
},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == [5, 5]
def test_run_with_number_filter_not_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test number filter with not equals condition."""
config = {
"title": "Test",
"variable": ["sys", "numbers"],
"filter_by": {
"enabled": True,
"condition": "",
"value": "5",
},
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == [1, 3, 7, 9]
def test_run_with_number_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test number ordering in ascending order."""
config = {
"title": "Test",
"variable": ["sys", "numbers"],
"filter_by": {"enabled": False},
"order_by": {
"enabled": True,
"value": "asc",
},
"limit": {"enabled": False},
}
mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = ListOperatorNode(
id="test",
config=config,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["result"].value == [1, 3, 5, 7, 9]

View File

@ -0,0 +1,225 @@
import pytest
from pydantic import ValidationError
from core.workflow.enums import ErrorStrategy
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
class TestTemplateTransformNodeData:
"""Test suite for TemplateTransformNodeData entity."""
def test_valid_template_transform_node_data(self):
"""Test creating valid TemplateTransformNodeData."""
data = {
"title": "Template Transform",
"desc": "Transform data using Jinja2 template",
"variables": [
{"variable": "name", "value_selector": ["sys", "user_name"]},
{"variable": "age", "value_selector": ["sys", "user_age"]},
],
"template": "Hello {{ name }}, you are {{ age }} years old!",
}
node_data = TemplateTransformNodeData.model_validate(data)
assert node_data.title == "Template Transform"
assert node_data.desc == "Transform data using Jinja2 template"
assert len(node_data.variables) == 2
assert node_data.variables[0].variable == "name"
assert node_data.variables[0].value_selector == ["sys", "user_name"]
assert node_data.variables[1].variable == "age"
assert node_data.variables[1].value_selector == ["sys", "user_age"]
assert node_data.template == "Hello {{ name }}, you are {{ age }} years old!"
def test_template_transform_node_data_with_empty_variables(self):
"""Test TemplateTransformNodeData with no variables."""
data = {
"title": "Static Template",
"variables": [],
"template": "This is a static template with no variables.",
}
node_data = TemplateTransformNodeData.model_validate(data)
assert node_data.title == "Static Template"
assert len(node_data.variables) == 0
assert node_data.template == "This is a static template with no variables."
def test_template_transform_node_data_with_complex_template(self):
"""Test TemplateTransformNodeData with complex Jinja2 template."""
data = {
"title": "Complex Template",
"variables": [
{"variable": "items", "value_selector": ["sys", "item_list"]},
{"variable": "total", "value_selector": ["sys", "total_count"]},
],
"template": (
"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}. Total: {{ total }}"
),
}
node_data = TemplateTransformNodeData.model_validate(data)
assert node_data.title == "Complex Template"
assert len(node_data.variables) == 2
assert "{% for item in items %}" in node_data.template
assert "{{ total }}" in node_data.template
def test_template_transform_node_data_with_error_strategy(self):
"""Test TemplateTransformNodeData with error handling strategy."""
data = {
"title": "Template with Error Handling",
"variables": [{"variable": "value", "value_selector": ["sys", "input"]}],
"template": "{{ value }}",
"error_strategy": "fail-branch",
}
node_data = TemplateTransformNodeData.model_validate(data)
assert node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
def test_template_transform_node_data_with_retry_config(self):
"""Test TemplateTransformNodeData with retry configuration."""
data = {
"title": "Template with Retry",
"variables": [{"variable": "data", "value_selector": ["sys", "data"]}],
"template": "{{ data }}",
"retry_config": {"enabled": True, "max_retries": 3, "retry_interval": 1000},
}
node_data = TemplateTransformNodeData.model_validate(data)
assert node_data.retry_config.enabled is True
assert node_data.retry_config.max_retries == 3
assert node_data.retry_config.retry_interval == 1000
def test_template_transform_node_data_missing_required_fields(self):
"""Test that missing required fields raises ValidationError."""
data = {
"title": "Incomplete Template",
# Missing 'variables' and 'template'
}
with pytest.raises(ValidationError) as exc_info:
TemplateTransformNodeData.model_validate(data)
errors = exc_info.value.errors()
assert len(errors) >= 2
error_fields = {error["loc"][0] for error in errors}
assert "variables" in error_fields
assert "template" in error_fields
def test_template_transform_node_data_invalid_variable_selector(self):
"""Test that invalid variable selector format raises ValidationError."""
data = {
"title": "Invalid Variable",
"variables": [
{"variable": "name", "value_selector": "invalid_format"} # Should be list
],
"template": "{{ name }}",
}
with pytest.raises(ValidationError):
TemplateTransformNodeData.model_validate(data)
def test_template_transform_node_data_with_default_value_dict(self):
"""Test TemplateTransformNodeData with default value dictionary."""
data = {
"title": "Template with Defaults",
"variables": [
{"variable": "name", "value_selector": ["sys", "user_name"]},
{"variable": "greeting", "value_selector": ["sys", "greeting"]},
],
"template": "{{ greeting }} {{ name }}!",
"default_value_dict": {"greeting": "Hello", "name": "Guest"},
}
node_data = TemplateTransformNodeData.model_validate(data)
assert node_data.default_value_dict == {"greeting": "Hello", "name": "Guest"}
def test_template_transform_node_data_with_nested_selectors(self):
"""Test TemplateTransformNodeData with nested variable selectors."""
data = {
"title": "Nested Selectors",
"variables": [
{"variable": "user_info", "value_selector": ["sys", "user", "profile", "name"]},
{"variable": "settings", "value_selector": ["sys", "config", "app", "theme"]},
],
"template": "User: {{ user_info }}, Theme: {{ settings }}",
}
node_data = TemplateTransformNodeData.model_validate(data)
assert len(node_data.variables) == 2
assert node_data.variables[0].value_selector == ["sys", "user", "profile", "name"]
assert node_data.variables[1].value_selector == ["sys", "config", "app", "theme"]
def test_template_transform_node_data_with_multiline_template(self):
"""Test TemplateTransformNodeData with multiline template."""
data = {
"title": "Multiline Template",
"variables": [
{"variable": "title", "value_selector": ["sys", "title"]},
{"variable": "content", "value_selector": ["sys", "content"]},
],
"template": """
# {{ title }}
{{ content }}
---
Generated by Template Transform Node
""",
}
node_data = TemplateTransformNodeData.model_validate(data)
assert "# {{ title }}" in node_data.template
assert "{{ content }}" in node_data.template
assert "Generated by Template Transform Node" in node_data.template
def test_template_transform_node_data_serialization(self):
"""Test that TemplateTransformNodeData can be serialized and deserialized."""
original_data = {
"title": "Serialization Test",
"desc": "Test serialization",
"variables": [{"variable": "test", "value_selector": ["sys", "test"]}],
"template": "{{ test }}",
}
node_data = TemplateTransformNodeData.model_validate(original_data)
serialized = node_data.model_dump()
deserialized = TemplateTransformNodeData.model_validate(serialized)
assert deserialized.title == node_data.title
assert deserialized.desc == node_data.desc
assert len(deserialized.variables) == len(node_data.variables)
assert deserialized.template == node_data.template
def test_template_transform_node_data_with_special_characters(self):
"""Test TemplateTransformNodeData with special characters in template."""
data = {
"title": "Special Characters",
"variables": [{"variable": "text", "value_selector": ["sys", "input"]}],
"template": "Special: {{ text }} | Symbols: @#$%^&*() | Unicode: 你好 🎉",
}
node_data = TemplateTransformNodeData.model_validate(data)
assert "@#$%^&*()" in node_data.template
assert "你好" in node_data.template
assert "🎉" in node_data.template
def test_template_transform_node_data_empty_template(self):
"""Test TemplateTransformNodeData with empty template string."""
data = {
"title": "Empty Template",
"variables": [],
"template": "",
}
node_data = TemplateTransformNodeData.model_validate(data)
assert node_data.template == ""
assert len(node_data.variables) == 0

View File

@ -0,0 +1,414 @@
from unittest.mock import MagicMock, patch
import pytest
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.helper.code_executor.code_executor import CodeExecutionError
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowType
class TestTemplateTransformNode:
"""Comprehensive test suite for TemplateTransformNode."""
@pytest.fixture
def mock_graph_runtime_state(self):
"""Create a mock GraphRuntimeState with variable pool."""
mock_state = MagicMock(spec=GraphRuntimeState)
mock_variable_pool = MagicMock()
mock_state.variable_pool = mock_variable_pool
return mock_state
@pytest.fixture
def mock_graph(self):
"""Create a mock Graph."""
return MagicMock(spec=Graph)
@pytest.fixture
def graph_init_params(self):
"""Create a mock GraphInitParams."""
return GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from="test",
invoke_from="test",
call_depth=0,
)
@pytest.fixture
def basic_node_data(self):
"""Create basic node data for testing."""
return {
"title": "Template Transform",
"desc": "Transform data using template",
"variables": [
{"variable": "name", "value_selector": ["sys", "user_name"]},
{"variable": "age", "value_selector": ["sys", "user_age"]},
],
"template": "Hello {{ name }}, you are {{ age }} years old!",
}
def test_node_initialization(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test that TemplateTransformNode initializes correctly."""
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
assert node.node_type == NodeType.TEMPLATE_TRANSFORM
assert node._node_data.title == "Template Transform"
assert len(node._node_data.variables) == 2
assert node._node_data.template == "Hello {{ name }}, you are {{ age }} years old!"
def test_get_title(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _get_title method."""
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
assert node._get_title() == "Template Transform"
def test_get_description(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _get_description method."""
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
assert node._get_description() == "Transform data using template"
def test_get_error_strategy(self, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _get_error_strategy method."""
node_data = {
"title": "Test",
"variables": [],
"template": "test",
"error_strategy": "fail-branch",
}
node = TemplateTransformNode(
id="test_node",
config=node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH
def test_get_default_config(self):
"""Test get_default_config class method."""
config = TemplateTransformNode.get_default_config()
assert config["type"] == "template-transform"
assert "config" in config
assert "variables" in config["config"]
assert "template" in config["config"]
assert config["config"]["template"] == "{{ arg1 }}"
def test_version(self):
"""Test version class method."""
assert TemplateTransformNode.version() == "1"
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
def test_run_simple_template(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run with simple template transformation."""
# Setup mock variable pool
mock_name_value = MagicMock()
mock_name_value.to_object.return_value = "Alice"
mock_age_value = MagicMock()
mock_age_value.to_object.return_value = 30
variable_map = {
("sys", "user_name"): mock_name_value,
("sys", "user_age"): mock_age_value,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
# Setup mock executor
mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "Hello Alice, you are 30 years old!"
assert result.inputs["name"] == "Alice"
assert result.inputs["age"] == 30
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with None variable values."""
node_data = {
"title": "Test",
"variables": [{"variable": "value", "value_selector": ["sys", "missing"]}],
"template": "Value: {{ value }}",
}
mock_graph_runtime_state.variable_pool.get.return_value = None
mock_execute.return_value = {"result": "Value: "}
node = TemplateTransformNode(
id="test_node",
config=node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.inputs["value"] is None
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
def test_run_with_code_execution_error(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when code execution fails."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.side_effect = CodeExecutionError("Template syntax error")
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Template syntax error" in result.error
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when output exceeds maximum length."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Output length exceeds" in result.error
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
def test_run_with_complex_jinja2_template(
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run with complex Jinja2 template including loops and conditions."""
node_data = {
"title": "Complex Template",
"variables": [
{"variable": "items", "value_selector": ["sys", "items"]},
{"variable": "show_total", "value_selector": ["sys", "show_total"]},
],
"template": (
"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"
"{% if show_total %} (Total: {{ items|length }}){% endif %}"
),
}
mock_items = MagicMock()
mock_items.to_object.return_value = ["apple", "banana", "orange"]
mock_show_total = MagicMock()
mock_show_total.to_object.return_value = True
variable_map = {
("sys", "items"): mock_items,
("sys", "show_total"): mock_show_total,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
node = TemplateTransformNode(
id="test_node",
config=node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "apple, banana, orange (Total: 3)"
def test_extract_variable_selector_to_variable_mapping(self):
"""Test _extract_variable_selector_to_variable_mapping class method."""
node_data = {
"title": "Test",
"variables": [
{"variable": "var1", "value_selector": ["sys", "input1"]},
{"variable": "var2", "value_selector": ["sys", "input2"]},
],
"template": "{{ var1 }} {{ var2 }}",
}
mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping(
graph_config={}, node_id="node_123", node_data=node_data
)
assert "node_123.var1" in mapping
assert "node_123.var2" in mapping
assert mapping["node_123.var1"] == ["sys", "input1"]
assert mapping["node_123.var2"] == ["sys", "input2"]
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with no variables (static template)."""
node_data = {
"title": "Static Template",
"variables": [],
"template": "This is a static message.",
}
mock_execute.return_value = {"result": "This is a static message."}
node = TemplateTransformNode(
id="test_node",
config=node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "This is a static message."
assert result.inputs == {}
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with numeric variable values."""
node_data = {
"title": "Numeric Template",
"variables": [
{"variable": "price", "value_selector": ["sys", "price"]},
{"variable": "quantity", "value_selector": ["sys", "quantity"]},
],
"template": "Total: ${{ price * quantity }}",
}
mock_price = MagicMock()
mock_price.to_object.return_value = 10.5
mock_quantity = MagicMock()
mock_quantity.to_object.return_value = 3
variable_map = {
("sys", "price"): mock_price,
("sys", "quantity"): mock_quantity,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "Total: $31.5"}
node = TemplateTransformNode(
id="test_node",
config=node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "Total: $31.5"
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with dictionary variable values."""
node_data = {
"title": "Dict Template",
"variables": [{"variable": "user", "value_selector": ["sys", "user_data"]}],
"template": "Name: {{ user.name }}, Email: {{ user.email }}",
}
mock_user = MagicMock()
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
mock_graph_runtime_state.variable_pool.get.return_value = mock_user
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
node = TemplateTransformNode(
id="test_node",
config=node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "John Doe" in result.outputs["output"]
assert "john@example.com" in result.outputs["output"]
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with list variable values."""
node_data = {
"title": "List Template",
"variables": [{"variable": "tags", "value_selector": ["sys", "tags"]}],
"template": "Tags: {% for tag in tags %}#{{ tag }} {% endfor %}",
}
mock_tags = MagicMock()
mock_tags.to_object.return_value = ["python", "ai", "workflow"]
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
node = TemplateTransformNode(
id="test_node",
config=node_data,
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "#python" in result.outputs["output"]
assert "#ai" in result.outputs["output"]
assert "#workflow" in result.outputs["output"]

View File

@ -0,0 +1,160 @@
import sys
import types
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch
import pytest
from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment
from core.workflow.entities import GraphInitParams
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
if TYPE_CHECKING: # pragma: no cover - imported for type checking only
from core.workflow.nodes.tool.tool_node import ToolNode
@pytest.fixture
def tool_node(monkeypatch) -> "ToolNode":
module_name = "core.ops.ops_trace_manager"
if module_name not in sys.modules:
ops_stub = types.ModuleType(module_name)
ops_stub.TraceQueueManager = object # pragma: no cover - stub attribute
ops_stub.TraceTask = object # pragma: no cover - stub attribute
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from core.workflow.nodes.tool.tool_node import ToolNode
graph_config: dict[str, Any] = {
"nodes": [
{
"id": "tool-node",
"data": {
"type": "tool",
"title": "Tool",
"desc": "",
"provider_id": "provider",
"provider_type": "builtin",
"provider_name": "provider",
"tool_name": "tool",
"tool_label": "tool",
"tool_configurations": {},
"tool_parameters": {},
},
}
],
"edges": [],
}
init_params = GraphInitParams(
tenant_id="tenant-id",
app_id="app-id",
workflow_id="workflow-id",
graph_config=graph_config,
user_id="user-id",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
variable_pool = VariablePool(system_variables=SystemVariable(user_id="user-id"))
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
config = graph_config["nodes"][0]
node = ToolNode(
id="node-instance",
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config["data"])
return node
def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
events: list[Any] = []
try:
while True:
events.append(next(generator))
except StopIteration as stop:
return events, stop.value
def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
def _identity_transform(messages, *_args, **_kwargs):
return messages
tool_runtime = MagicMock()
with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform):
generator = tool_node._transform_message(
messages=iter([message]),
tool_info={"provider_type": "builtin", "provider_id": "provider"},
parameters_for_log={},
user_id="user-id",
tenant_id="tenant-id",
node_id=tool_node._node_id,
tool_runtime=tool_runtime,
)
return _collect_events(generator)
def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
file_obj = File(
tenant_id="tenant-id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
extension=".pdf",
mime_type="application/pdf",
size=123,
storage_key="file-key",
)
message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text="/files/tools/file-id.pdf"),
meta={"file": file_obj},
)
events, usage = _run_transform(tool_node, message)
assert isinstance(usage, LLMUsage)
chunk_events = [event for event in events if isinstance(event, StreamChunkEvent)]
assert chunk_events
assert chunk_events[0].chunk == "File: /files/tools/file-id.pdf\n"
completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
assert len(completed_events) == 1
outputs = completed_events[0].node_run_result.outputs
assert outputs["text"] == "File: /files/tools/file-id.pdf\n"
files_segment = outputs["files"]
assert isinstance(files_segment, ArrayFileSegment)
assert files_segment.value == [file_obj]
def test_plain_link_messages_remain_links(tool_node: "ToolNode"):
message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
meta=None,
)
events, _ = _run_transform(tool_node, message)
chunk_events = [event for event in events if isinstance(event, StreamChunkEvent)]
assert chunk_events
assert chunk_events[0].chunk == "Link: https://dify.ai\n"
completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
assert len(completed_events) == 1
files_segment = completed_events[0].node_run_result.outputs["files"]
assert isinstance(files_segment, ArrayFileSegment)
assert files_segment.value == []

View File

@ -0,0 +1,966 @@
"""
Comprehensive unit tests for Tool models.
This test suite covers:
- ToolProvider model validation (BuiltinToolProvider, ApiToolProvider)
- BuiltinToolProvider relationships and credential management
- ApiToolProvider credential storage and encryption
- Tool OAuth client models
- ToolLabelBinding relationships
"""
import json
from uuid import uuid4
from core.tools.entities.tool_entities import ApiProviderSchemaType
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
ToolLabelBinding,
ToolOAuthSystemClient,
ToolOAuthTenantClient,
)
class TestBuiltinToolProviderValidation:
"""Test suite for BuiltinToolProvider model validation and operations."""
def test_builtin_tool_provider_creation_with_required_fields(self):
"""Test creating a builtin tool provider with all required fields."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
provider_name = "google"
credentials = {"api_key": "test_key_123"}
# Act
builtin_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
name="Google API Key 1",
)
# Assert
assert builtin_provider.tenant_id == tenant_id
assert builtin_provider.user_id == user_id
assert builtin_provider.provider == provider_name
assert builtin_provider.name == "Google API Key 1"
assert builtin_provider.encrypted_credentials == json.dumps(credentials)
def test_builtin_tool_provider_credentials_property(self):
"""Test credentials property parses JSON correctly."""
# Arrange
credentials_data = {
"api_key": "sk-test123",
"auth_type": "api_key",
"endpoint": "https://api.example.com",
}
builtin_provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="custom_provider",
name="Custom Provider Key",
encrypted_credentials=json.dumps(credentials_data),
)
# Act
result = builtin_provider.credentials
# Assert
assert result == credentials_data
assert result["api_key"] == "sk-test123"
assert result["auth_type"] == "api_key"
def test_builtin_tool_provider_credentials_empty_when_none(self):
"""Test credentials property returns empty dict when encrypted_credentials is None."""
# Arrange
builtin_provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="test_provider",
name="Test Provider",
encrypted_credentials=None,
)
# Act
result = builtin_provider.credentials
# Assert
assert result == {}
def test_builtin_tool_provider_credentials_empty_when_empty_string(self):
"""Test credentials property returns empty dict when encrypted_credentials is empty."""
# Arrange
builtin_provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="test_provider",
name="Test Provider",
encrypted_credentials="",
)
# Act
result = builtin_provider.credentials
# Assert
assert result == {}
def test_builtin_tool_provider_default_values(self):
"""Test builtin tool provider default values."""
# Arrange & Act
builtin_provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="test_provider",
name="Test Provider",
)
# Assert
assert builtin_provider.is_default is False
assert builtin_provider.credential_type == "api-key"
assert builtin_provider.expires_at == -1
def test_builtin_tool_provider_with_oauth_credential_type(self):
"""Test builtin tool provider with OAuth credential type."""
# Arrange
credentials = {
"access_token": "oauth_token_123",
"refresh_token": "refresh_token_456",
"token_type": "Bearer",
}
# Act
builtin_provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="google",
name="Google OAuth",
encrypted_credentials=json.dumps(credentials),
credential_type="oauth2",
expires_at=1735689600,
)
# Assert
assert builtin_provider.credential_type == "oauth2"
assert builtin_provider.expires_at == 1735689600
assert builtin_provider.credentials["access_token"] == "oauth_token_123"
def test_builtin_tool_provider_is_default_flag(self):
"""Test is_default flag for builtin tool provider."""
# Arrange
provider1 = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="google",
name="Google Key 1",
is_default=True,
)
provider2 = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="google",
name="Google Key 2",
is_default=False,
)
# Assert
assert provider1.is_default is True
assert provider2.is_default is False
def test_builtin_tool_provider_unique_constraint_fields(self):
"""Test unique constraint fields (tenant_id, provider, name)."""
# Arrange
tenant_id = str(uuid4())
provider_name = "google"
credential_name = "My Google Key"
# Act
builtin_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=str(uuid4()),
provider=provider_name,
name=credential_name,
)
# Assert - these fields form unique constraint
assert builtin_provider.tenant_id == tenant_id
assert builtin_provider.provider == provider_name
assert builtin_provider.name == credential_name
def test_builtin_tool_provider_multiple_credentials_same_provider(self):
"""Test multiple credential sets for the same provider."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
provider = "openai"
# Act - create multiple credentials for same provider
provider1 = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
name="OpenAI Key 1",
encrypted_credentials=json.dumps({"api_key": "key1"}),
)
provider2 = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
name="OpenAI Key 2",
encrypted_credentials=json.dumps({"api_key": "key2"}),
)
# Assert - different names allow multiple credentials
assert provider1.provider == provider2.provider
assert provider1.name != provider2.name
assert provider1.credentials != provider2.credentials
class TestApiToolProviderValidation:
"""Test suite for ApiToolProvider model validation and operations."""
def test_api_tool_provider_creation_with_required_fields(self):
"""Test creating an API tool provider with all required fields."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
provider_name = "Custom API"
schema = '{"openapi": "3.0.0", "info": {"title": "Test API"}}'
tools = [{"name": "test_tool", "description": "A test tool"}]
credentials = {"auth_type": "api_key", "api_key_value": "test123"}
# Act
api_provider = ApiToolProvider(
tenant_id=tenant_id,
user_id=user_id,
name=provider_name,
icon='{"type": "emoji", "value": "🔧"}',
schema=schema,
schema_type_str="openapi",
description="Custom API for testing",
tools_str=json.dumps(tools),
credentials_str=json.dumps(credentials),
)
# Assert
assert api_provider.tenant_id == tenant_id
assert api_provider.user_id == user_id
assert api_provider.name == provider_name
assert api_provider.schema == schema
assert api_provider.schema_type_str == "openapi"
assert api_provider.description == "Custom API for testing"
def test_api_tool_provider_schema_type_property(self):
"""Test schema_type property converts string to enum."""
# Arrange
api_provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Test API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Test",
tools_str="[]",
credentials_str="{}",
)
# Act
result = api_provider.schema_type
# Assert
assert result == ApiProviderSchemaType.OPENAPI
def test_api_tool_provider_tools_property(self):
"""Test tools property parses JSON and returns ApiToolBundle list."""
# Arrange
tools_data = [
{
"author": "test",
"server_url": "https://api.weather.com",
"method": "get",
"summary": "Get weather information",
"operation_id": "getWeather",
"parameters": [],
"openapi": {
"operation_id": "getWeather",
"parameters": [],
"method": "get",
"path": "/weather",
"server_url": "https://api.weather.com",
},
},
{
"author": "test",
"server_url": "https://api.location.com",
"method": "get",
"summary": "Get location data",
"operation_id": "getLocation",
"parameters": [],
"openapi": {
"operation_id": "getLocation",
"parameters": [],
"method": "get",
"path": "/location",
"server_url": "https://api.location.com",
},
},
]
api_provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Weather API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Weather API",
tools_str=json.dumps(tools_data),
credentials_str="{}",
)
# Act
result = api_provider.tools
# Assert
assert len(result) == 2
assert result[0].operation_id == "getWeather"
assert result[1].operation_id == "getLocation"
def test_api_tool_provider_credentials_property(self):
"""Test credentials property parses JSON correctly."""
# Arrange
credentials_data = {
"auth_type": "api_key_header",
"api_key_header": "Authorization",
"api_key_value": "Bearer test_token",
"api_key_header_prefix": "bearer",
}
api_provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Secure API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Secure API",
tools_str="[]",
credentials_str=json.dumps(credentials_data),
)
# Act
result = api_provider.credentials
# Assert
assert result["auth_type"] == "api_key_header"
assert result["api_key_header"] == "Authorization"
assert result["api_key_value"] == "Bearer test_token"
def test_api_tool_provider_with_privacy_policy(self):
"""Test API tool provider with privacy policy."""
# Arrange
privacy_policy_url = "https://example.com/privacy"
# Act
api_provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Privacy API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="API with privacy policy",
tools_str="[]",
credentials_str="{}",
privacy_policy=privacy_policy_url,
)
# Assert
assert api_provider.privacy_policy == privacy_policy_url
def test_api_tool_provider_with_custom_disclaimer(self):
"""Test API tool provider with custom disclaimer."""
# Arrange
disclaimer = "This API is provided as-is without warranty."
# Act
api_provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Disclaimer API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="API with disclaimer",
tools_str="[]",
credentials_str="{}",
custom_disclaimer=disclaimer,
)
# Assert
assert api_provider.custom_disclaimer == disclaimer
def test_api_tool_provider_default_custom_disclaimer(self):
"""Test API tool provider default custom_disclaimer is empty string."""
# Arrange & Act
api_provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Default API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="API",
tools_str="[]",
credentials_str="{}",
)
# Assert
assert api_provider.custom_disclaimer == ""
def test_api_tool_provider_unique_constraint_fields(self):
"""Test unique constraint fields (name, tenant_id)."""
# Arrange
tenant_id = str(uuid4())
provider_name = "Unique API"
# Act
api_provider = ApiToolProvider(
tenant_id=tenant_id,
user_id=str(uuid4()),
name=provider_name,
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Unique API",
tools_str="[]",
credentials_str="{}",
)
# Assert - these fields form unique constraint
assert api_provider.tenant_id == tenant_id
assert api_provider.name == provider_name
def test_api_tool_provider_with_no_auth(self):
"""Test API tool provider with no authentication."""
# Arrange
credentials = {"auth_type": "none"}
# Act
api_provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Public API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Public API with no auth",
tools_str="[]",
credentials_str=json.dumps(credentials),
)
# Assert
assert api_provider.credentials["auth_type"] == "none"
def test_api_tool_provider_with_api_key_query_auth(self):
"""Test API tool provider with API key in query parameter."""
# Arrange
credentials = {
"auth_type": "api_key_query",
"api_key_query_param": "apikey",
"api_key_value": "my_secret_key",
}
# Act
api_provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Query Auth API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="API with query auth",
tools_str="[]",
credentials_str=json.dumps(credentials),
)
# Assert
assert api_provider.credentials["auth_type"] == "api_key_query"
assert api_provider.credentials["api_key_query_param"] == "apikey"
class TestToolOAuthModels:
"""Test suite for OAuth client models (system and tenant level)."""
def test_oauth_system_client_creation(self):
"""Test creating a system-level OAuth client."""
# Arrange
plugin_id = "builtin.google"
provider = "google"
oauth_params = json.dumps(
{"client_id": "system_client_id", "client_secret": "system_secret", "scope": "email profile"}
)
# Act
oauth_client = ToolOAuthSystemClient(
plugin_id=plugin_id,
provider=provider,
encrypted_oauth_params=oauth_params,
)
# Assert
assert oauth_client.plugin_id == plugin_id
assert oauth_client.provider == provider
assert oauth_client.encrypted_oauth_params == oauth_params
def test_oauth_system_client_unique_constraint(self):
"""Test unique constraint on plugin_id and provider."""
# Arrange
plugin_id = "builtin.github"
provider = "github"
# Act
oauth_client = ToolOAuthSystemClient(
plugin_id=plugin_id,
provider=provider,
encrypted_oauth_params="{}",
)
# Assert - these fields form unique constraint
assert oauth_client.plugin_id == plugin_id
assert oauth_client.provider == provider
def test_oauth_tenant_client_creation(self):
"""Test creating a tenant-level OAuth client."""
# Arrange
tenant_id = str(uuid4())
plugin_id = "builtin.google"
provider = "google"
# Act
oauth_client = ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider,
)
# Set encrypted_oauth_params after creation (it has init=False)
oauth_params = json.dumps({"client_id": "tenant_client_id", "client_secret": "tenant_secret"})
oauth_client.encrypted_oauth_params = oauth_params
# Assert
assert oauth_client.tenant_id == tenant_id
assert oauth_client.plugin_id == plugin_id
assert oauth_client.provider == provider
def test_oauth_tenant_client_enabled_default(self):
"""Test OAuth tenant client enabled flag has init=False and uses server default."""
# Arrange & Act
oauth_client = ToolOAuthTenantClient(
tenant_id=str(uuid4()),
plugin_id="builtin.slack",
provider="slack",
)
# Assert - enabled has init=False, so it won't be set until saved to DB
# We can manually set it to test the field exists
oauth_client.enabled = True
assert oauth_client.enabled is True
def test_oauth_tenant_client_oauth_params_property(self):
"""Test oauth_params property parses JSON correctly."""
# Arrange
params_data = {
"client_id": "test_client_123",
"client_secret": "secret_456",
"redirect_uri": "https://app.example.com/callback",
}
oauth_client = ToolOAuthTenantClient(
tenant_id=str(uuid4()),
plugin_id="builtin.dropbox",
provider="dropbox",
)
# Set encrypted_oauth_params after creation (it has init=False)
oauth_client.encrypted_oauth_params = json.dumps(params_data)
# Act
result = oauth_client.oauth_params
# Assert
assert result == params_data
assert result["client_id"] == "test_client_123"
assert result["redirect_uri"] == "https://app.example.com/callback"
def test_oauth_tenant_client_oauth_params_empty_when_none(self):
"""Test oauth_params returns empty dict when encrypted_oauth_params is None."""
# Arrange
oauth_client = ToolOAuthTenantClient(
tenant_id=str(uuid4()),
plugin_id="builtin.test",
provider="test",
)
# encrypted_oauth_params has init=False, set it to None
oauth_client.encrypted_oauth_params = None
# Act
result = oauth_client.oauth_params
# Assert
assert result == {}
def test_oauth_tenant_client_disabled_state(self):
"""Test OAuth tenant client can be disabled."""
# Arrange
oauth_client = ToolOAuthTenantClient(
tenant_id=str(uuid4()),
plugin_id="builtin.microsoft",
provider="microsoft",
)
# Act
oauth_client.enabled = False
# Assert
assert oauth_client.enabled is False
class TestToolLabelBinding:
"""Test suite for ToolLabelBinding model."""
def test_tool_label_binding_creation(self):
"""Test creating a tool label binding."""
# Arrange
tool_id = "google.search"
tool_type = "builtin"
label_name = "search"
# Act
label_binding = ToolLabelBinding(
tool_id=tool_id,
tool_type=tool_type,
label_name=label_name,
)
# Assert
assert label_binding.tool_id == tool_id
assert label_binding.tool_type == tool_type
assert label_binding.label_name == label_name
def test_tool_label_binding_unique_constraint(self):
"""Test unique constraint on tool_id and label_name."""
# Arrange
tool_id = "openai.text_generation"
label_name = "text"
# Act
label_binding = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
label_name=label_name,
)
# Assert - these fields form unique constraint
assert label_binding.tool_id == tool_id
assert label_binding.label_name == label_name
def test_tool_label_binding_multiple_labels_same_tool(self):
"""Test multiple labels can be bound to the same tool."""
# Arrange
tool_id = "google.search"
tool_type = "builtin"
# Act
binding1 = ToolLabelBinding(
tool_id=tool_id,
tool_type=tool_type,
label_name="search",
)
binding2 = ToolLabelBinding(
tool_id=tool_id,
tool_type=tool_type,
label_name="productivity",
)
# Assert
assert binding1.tool_id == binding2.tool_id
assert binding1.label_name != binding2.label_name
def test_tool_label_binding_different_tool_types(self):
"""Test label bindings for different tool types."""
# Arrange
tool_types = ["builtin", "api", "workflow"]
# Act & Assert
for tool_type in tool_types:
binding = ToolLabelBinding(
tool_id=f"test_tool_{tool_type}",
tool_type=tool_type,
label_name="test",
)
assert binding.tool_type == tool_type
class TestCredentialStorage:
"""Test suite for credential storage and encryption patterns."""
def test_builtin_provider_credential_storage_format(self):
"""Test builtin provider stores credentials as JSON string."""
# Arrange
credentials = {
"api_key": "sk-test123",
"endpoint": "https://api.example.com",
"timeout": 30,
}
# Act
provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="test",
name="Test Provider",
encrypted_credentials=json.dumps(credentials),
)
# Assert
assert isinstance(provider.encrypted_credentials, str)
assert provider.credentials == credentials
def test_api_provider_credential_storage_format(self):
"""Test API provider stores credentials as JSON string."""
# Arrange
credentials = {
"auth_type": "api_key_header",
"api_key_header": "X-API-Key",
"api_key_value": "secret_key_789",
}
# Act
provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Test API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Test",
tools_str="[]",
credentials_str=json.dumps(credentials),
)
# Assert
assert isinstance(provider.credentials_str, str)
assert provider.credentials == credentials
def test_builtin_provider_complex_credential_structure(self):
"""Test builtin provider with complex nested credential structure."""
# Arrange
credentials = {
"auth_type": "oauth2",
"oauth_config": {
"access_token": "token123",
"refresh_token": "refresh456",
"expires_in": 3600,
"token_type": "Bearer",
},
"additional_headers": {"X-Custom-Header": "value"},
}
# Act
provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="oauth_provider",
name="OAuth Provider",
encrypted_credentials=json.dumps(credentials),
)
# Assert
assert provider.credentials["oauth_config"]["access_token"] == "token123"
assert provider.credentials["additional_headers"]["X-Custom-Header"] == "value"
def test_api_provider_credential_update_pattern(self):
"""Test pattern for updating API provider credentials."""
# Arrange
original_credentials = {"auth_type": "api_key_header", "api_key_value": "old_key"}
provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
name="Update Test",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Test",
tools_str="[]",
credentials_str=json.dumps(original_credentials),
)
# Act - simulate credential update
new_credentials = {"auth_type": "api_key_header", "api_key_value": "new_key"}
provider.credentials_str = json.dumps(new_credentials)
# Assert
assert provider.credentials["api_key_value"] == "new_key"
def test_builtin_provider_credential_expiration(self):
"""Test builtin provider credential expiration tracking."""
# Arrange
future_timestamp = 1735689600 # Future date
past_timestamp = 1609459200 # Past date
# Act
active_provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="active",
name="Active Provider",
expires_at=future_timestamp,
)
expired_provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="expired",
name="Expired Provider",
expires_at=past_timestamp,
)
never_expires_provider = BuiltinToolProvider(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
provider="permanent",
name="Permanent Provider",
expires_at=-1,
)
# Assert
assert active_provider.expires_at == future_timestamp
assert expired_provider.expires_at == past_timestamp
assert never_expires_provider.expires_at == -1
def test_oauth_client_credential_storage(self):
"""Test OAuth client credential storage pattern."""
# Arrange
oauth_credentials = {
"client_id": "oauth_client_123",
"client_secret": "oauth_secret_456",
"authorization_url": "https://oauth.example.com/authorize",
"token_url": "https://oauth.example.com/token",
"scope": "read write",
}
# Act
system_client = ToolOAuthSystemClient(
plugin_id="builtin.oauth_test",
provider="oauth_test",
encrypted_oauth_params=json.dumps(oauth_credentials),
)
tenant_client = ToolOAuthTenantClient(
tenant_id=str(uuid4()),
plugin_id="builtin.oauth_test",
provider="oauth_test",
)
# Set encrypted_oauth_params after creation (it has init=False)
tenant_client.encrypted_oauth_params = json.dumps(oauth_credentials)
# Assert
assert system_client.encrypted_oauth_params == json.dumps(oauth_credentials)
assert tenant_client.oauth_params == oauth_credentials
class TestToolProviderRelationships:
"""Test suite for tool provider relationships and associations."""
def test_builtin_provider_tenant_relationship(self):
"""Test builtin provider belongs to a tenant."""
# Arrange
tenant_id = str(uuid4())
# Act
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=str(uuid4()),
provider="test",
name="Test Provider",
)
# Assert
assert provider.tenant_id == tenant_id
def test_api_provider_user_relationship(self):
"""Test API provider belongs to a user."""
# Arrange
user_id = str(uuid4())
# Act
provider = ApiToolProvider(
tenant_id=str(uuid4()),
user_id=user_id,
name="User API",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Test",
tools_str="[]",
credentials_str="{}",
)
# Assert
assert provider.user_id == user_id
def test_multiple_providers_same_tenant(self):
"""Test multiple providers can belong to the same tenant."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
# Act
builtin1 = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider="google",
name="Google Key 1",
)
builtin2 = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider="openai",
name="OpenAI Key 1",
)
api1 = ApiToolProvider(
tenant_id=tenant_id,
user_id=user_id,
name="Custom API 1",
icon="{}",
schema="{}",
schema_type_str="openapi",
description="Test",
tools_str="[]",
credentials_str="{}",
)
# Assert
assert builtin1.tenant_id == tenant_id
assert builtin2.tenant_id == tenant_id
assert api1.tenant_id == tenant_id
def test_tool_label_bindings_for_provider_tools(self):
"""Test tool label bindings can be associated with provider tools."""
# Arrange
provider_name = "google"
tool_id = f"{provider_name}.search"
# Act
binding1 = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
label_name="search",
)
binding2 = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
label_name="web",
)
# Assert
assert binding1.tool_id == tool_id
assert binding2.tool_id == tool_id
assert binding1.label_name != binding2.label_name

View File

@ -6,10 +6,10 @@ from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_PrivateWorkflowPauseEntity,
@ -129,12 +129,14 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id=workflow_run_id,
state_owner_user_id=state_owner_user_id,
state=state,
pause_reasons=[],
)
# Assert
assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123"
assert result.workflow_execution_id == workflow_run_id
assert result.get_pause_reasons() == []
# Verify database interactions
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
@ -156,6 +158,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
pause_reasons=[],
)
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
@ -174,6 +177,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
pause_reasons=[],
)
@ -316,19 +320,10 @@ class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test _PrivateWorkflowPauseEntity class."""
def test_from_models(self, sample_workflow_pause: Mock):
"""Test creating _PrivateWorkflowPauseEntity from models."""
# Act
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
# Assert
assert isinstance(entity, _PrivateWorkflowPauseEntity)
assert entity._pause_model == sample_workflow_pause
def test_properties(self, sample_workflow_pause: Mock):
"""Test entity properties."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
# Act & Assert
assert entity.id == sample_workflow_pause.id
@ -338,7 +333,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
def test_get_state(self, sample_workflow_pause: Mock):
"""Test getting state from storage."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
@ -354,7 +349,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
def test_get_state_caching(self, sample_workflow_pause: Mock):
"""Test state caching in get_state method."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:

View File

@ -0,0 +1,932 @@
"""
Comprehensive unit tests for DatasetCollectionBindingService.
This module contains extensive unit tests for the DatasetCollectionBindingService class,
which handles dataset collection binding operations for vector database collections.
The DatasetCollectionBindingService provides methods for:
- Retrieving or creating dataset collection bindings by provider, model, and type
- Retrieving specific collection bindings by ID and type
- Managing collection bindings for different collection types (dataset, etc.)
Collection bindings are used to map embedding models (provider + model name) to
specific vector database collections, allowing datasets to share collections when
they use the same embedding model configuration.
This test suite ensures:
- Correct retrieval of existing bindings
- Proper creation of new bindings when they don't exist
- Accurate filtering by provider, model, and collection type
- Proper error handling for missing bindings
- Database transaction handling (add, commit)
- Collection name generation using Dataset.gen_collection_name_by_id
================================================================================
ARCHITECTURE OVERVIEW
================================================================================
The DatasetCollectionBindingService is a critical component in the Dify platform's
vector database management system. It serves as an abstraction layer between the
application logic and the underlying vector database collections.
Key Concepts:
1. Collection Binding: A mapping between an embedding model configuration
(provider + model name) and a vector database collection name. This allows
multiple datasets to share the same collection when they use identical
embedding models, improving resource efficiency.
2. Collection Type: Different types of collections can exist (e.g., "dataset",
"custom_type"). This allows for separation of collections based on their
intended use case or data structure.
3. Provider and Model: The combination of provider_name (e.g., "openai",
"cohere", "huggingface") and model_name (e.g., "text-embedding-ada-002")
uniquely identifies an embedding model configuration.
4. Collection Name Generation: When a new binding is created, a unique collection
name is generated using Dataset.gen_collection_name_by_id() with a UUID.
This ensures each binding has a unique collection identifier.
================================================================================
TESTING STRATEGY
================================================================================
This test suite follows a comprehensive testing strategy that covers:
1. Happy Path Scenarios:
- Successful retrieval of existing bindings
- Successful creation of new bindings
- Proper handling of default parameters
2. Edge Cases:
- Different collection types
- Various provider/model combinations
- Default vs explicit parameter usage
3. Error Handling:
- Missing bindings (for get_by_id_and_type)
- Database query failures
- Invalid parameter combinations
4. Database Interaction:
- Query construction and execution
- Transaction management (add, commit)
- Query chaining (where, order_by, first)
5. Mocking Strategy:
- Database session mocking
- Query builder chain mocking
- UUID generation mocking
- Collection name generation mocking
================================================================================
"""
"""
Import statements for the test module.
This section imports all necessary dependencies for testing the
DatasetCollectionBindingService, including:
- unittest.mock for creating mock objects
- pytest for test framework functionality
- uuid for UUID generation (used in collection name generation)
- Models and services from the application codebase
"""
from unittest.mock import Mock, patch
import pytest
from models.dataset import Dataset, DatasetCollectionBinding
from services.dataset_service import DatasetCollectionBindingService
# ============================================================================
# Test Data Factory
# ============================================================================
# The Test Data Factory pattern is used here to centralize the creation of
# test objects and mock instances. This approach provides several benefits:
#
# 1. Consistency: All test objects are created using the same factory methods,
# ensuring consistent structure across all tests.
#
# 2. Maintainability: If the structure of DatasetCollectionBinding or Dataset
# changes, we only need to update the factory methods rather than every
# individual test.
#
# 3. Reusability: Factory methods can be reused across multiple test classes,
# reducing code duplication.
#
# 4. Readability: Tests become more readable when they use descriptive factory
# method calls instead of complex object construction logic.
#
# ============================================================================
class DatasetCollectionBindingTestDataFactory:
"""
Factory class for creating test data and mock objects for dataset collection binding tests.
This factory provides static methods to create mock objects for:
- DatasetCollectionBinding instances
- Database query results
- Collection name generation results
The factory methods help maintain consistency across tests and reduce
code duplication when setting up test scenarios.
"""
@staticmethod
def create_collection_binding_mock(
binding_id: str = "binding-123",
provider_name: str = "openai",
model_name: str = "text-embedding-ada-002",
collection_name: str = "collection-abc",
collection_type: str = "dataset",
created_at=None,
**kwargs,
) -> Mock:
"""
Create a mock DatasetCollectionBinding with specified attributes.
Args:
binding_id: Unique identifier for the binding
provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
collection_name: Name of the vector database collection
collection_type: Type of collection (default: "dataset")
created_at: Optional datetime for creation timestamp
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as a DatasetCollectionBinding instance
"""
binding = Mock(spec=DatasetCollectionBinding)
binding.id = binding_id
binding.provider_name = provider_name
binding.model_name = model_name
binding.collection_name = collection_name
binding.type = collection_type
binding.created_at = created_at
for key, value in kwargs.items():
setattr(binding, key, value)
return binding
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
**kwargs,
) -> Mock:
"""
Create a mock Dataset for testing collection name generation.
Args:
dataset_id: Unique identifier for the dataset
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as a Dataset instance
"""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
# ============================================================================
# Tests for get_dataset_collection_binding
# ============================================================================
class TestDatasetCollectionBindingServiceGetBinding:
"""
Comprehensive unit tests for DatasetCollectionBindingService.get_dataset_collection_binding method.
This test class covers the main collection binding retrieval/creation functionality,
including various provider/model combinations, collection types, and edge cases.
The get_dataset_collection_binding method:
1. Queries for existing binding by provider_name, model_name, and collection_type
2. Orders results by created_at (ascending) and takes the first match
3. If no binding exists, creates a new one with:
- The provided provider_name and model_name
- A generated collection_name using Dataset.gen_collection_name_by_id
- The provided collection_type
4. Adds the new binding to the database session and commits
5. Returns the binding (either existing or newly created)
Test scenarios include:
- Retrieving existing bindings
- Creating new bindings when none exist
- Different collection types
- Database transaction handling
- Collection name generation
"""
@pytest.fixture
def mock_db_session(self):
"""
Mock database session for testing database operations.
Provides a mocked database session that can be used to verify:
- Query construction and execution
- Add operations for new bindings
- Commit operations for transaction completion
The mock is configured to return a query builder that supports
chaining operations like .where(), .order_by(), and .first().
"""
with patch("services.dataset_service.db.session") as mock_db:
yield mock_db
def test_get_dataset_collection_binding_existing_binding_success(self, mock_db_session):
"""
Test successful retrieval of an existing collection binding.
Verifies that when a binding already exists in the database for the given
provider, model, and collection type, the method returns the existing binding
without creating a new one.
This test ensures:
- The query is constructed correctly with all three filters
- Results are ordered by created_at
- The first matching binding is returned
- No new binding is created (db.session.add is not called)
- No commit is performed (db.session.commit is not called)
"""
# Arrange
provider_name = "openai"
model_name = "text-embedding-ada-002"
collection_type = "dataset"
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
binding_id="binding-123",
provider_name=provider_name,
model_name=model_name,
collection_type=collection_type,
)
# Mock the query chain: query().where().order_by().first()
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = existing_binding
mock_db_session.query.return_value = mock_query
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding(
provider_name=provider_name, model_name=model_name, collection_type=collection_type
)
# Assert
assert result == existing_binding
assert result.id == "binding-123"
assert result.provider_name == provider_name
assert result.model_name == model_name
assert result.type == collection_type
# Verify query was constructed correctly
# The query should be constructed with DatasetCollectionBinding as the model
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
# Verify the where clause was applied to filter by provider, model, and type
mock_query.where.assert_called_once()
# Verify the results were ordered by created_at (ascending)
# This ensures we get the oldest binding if multiple exist
mock_where.order_by.assert_called_once()
# Verify no new binding was created
# Since an existing binding was found, we should not create a new one
mock_db_session.add.assert_not_called()
# Verify no commit was performed
# Since no new binding was created, no database transaction is needed
mock_db_session.commit.assert_not_called()
def test_get_dataset_collection_binding_create_new_binding_success(self, mock_db_session):
"""
Test successful creation of a new collection binding when none exists.
Verifies that when no binding exists in the database for the given
provider, model, and collection type, the method creates a new binding
with a generated collection name and commits it to the database.
This test ensures:
- The query returns None (no existing binding)
- A new DatasetCollectionBinding is created with correct attributes
- Dataset.gen_collection_name_by_id is called to generate collection name
- The new binding is added to the database session
- The transaction is committed
- The newly created binding is returned
"""
# Arrange
provider_name = "cohere"
model_name = "embed-english-v3.0"
collection_type = "dataset"
generated_collection_name = "collection-generated-xyz"
# Mock the query chain to return None (no existing binding)
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = None # No existing binding
mock_db_session.query.return_value = mock_query
# Mock Dataset.gen_collection_name_by_id to return a generated name
with patch("services.dataset_service.Dataset.gen_collection_name_by_id") as mock_gen_name:
mock_gen_name.return_value = generated_collection_name
# Mock uuid.uuid4 for the collection name generation
mock_uuid = "test-uuid-123"
with patch("services.dataset_service.uuid.uuid4", return_value=mock_uuid):
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding(
provider_name=provider_name, model_name=model_name, collection_type=collection_type
)
# Assert
assert result is not None
assert result.provider_name == provider_name
assert result.model_name == model_name
assert result.type == collection_type
assert result.collection_name == generated_collection_name
# Verify Dataset.gen_collection_name_by_id was called with the generated UUID
# This method generates a unique collection name based on the UUID
# The UUID is converted to string before passing to the method
mock_gen_name.assert_called_once_with(str(mock_uuid))
# Verify new binding was added to the database session
# The add method should be called exactly once with the new binding instance
mock_db_session.add.assert_called_once()
# Extract the binding that was added to verify its properties
added_binding = mock_db_session.add.call_args[0][0]
# Verify the added binding is an instance of DatasetCollectionBinding
# This ensures we're creating the correct type of object
assert isinstance(added_binding, DatasetCollectionBinding)
# Verify all the binding properties are set correctly
# These should match the input parameters to the method
assert added_binding.provider_name == provider_name
assert added_binding.model_name == model_name
assert added_binding.type == collection_type
# Verify the collection name was set from the generated name
# This ensures the binding has a valid collection identifier
assert added_binding.collection_name == generated_collection_name
# Verify the transaction was committed
# This ensures the new binding is persisted to the database
mock_db_session.commit.assert_called_once()
def test_get_dataset_collection_binding_different_collection_type(self, mock_db_session):
"""
Test retrieval with a different collection type (not "dataset").
Verifies that the method correctly filters by collection_type, allowing
different types of collections to coexist with the same provider/model
combination.
This test ensures:
- Collection type is properly used as a filter in the query
- Different collection types can have separate bindings
- The correct binding is returned based on type
"""
# Arrange
provider_name = "openai"
model_name = "text-embedding-ada-002"
collection_type = "custom_type"
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
binding_id="binding-456",
provider_name=provider_name,
model_name=model_name,
collection_type=collection_type,
)
# Mock the query chain
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = existing_binding
mock_db_session.query.return_value = mock_query
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding(
provider_name=provider_name, model_name=model_name, collection_type=collection_type
)
# Assert
assert result == existing_binding
assert result.type == collection_type
# Verify query was constructed with the correct type filter
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
mock_query.where.assert_called_once()
def test_get_dataset_collection_binding_default_collection_type(self, mock_db_session):
"""
Test retrieval with default collection type ("dataset").
Verifies that when collection_type is not provided, it defaults to "dataset"
as specified in the method signature.
This test ensures:
- The default value "dataset" is used when type is not specified
- The query correctly filters by the default type
"""
# Arrange
provider_name = "openai"
model_name = "text-embedding-ada-002"
# collection_type defaults to "dataset" in method signature
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
binding_id="binding-789",
provider_name=provider_name,
model_name=model_name,
collection_type="dataset", # Default type
)
# Mock the query chain
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = existing_binding
mock_db_session.query.return_value = mock_query
# Act - call without specifying collection_type (uses default)
result = DatasetCollectionBindingService.get_dataset_collection_binding(
provider_name=provider_name, model_name=model_name
)
# Assert
assert result == existing_binding
assert result.type == "dataset"
# Verify query was constructed correctly
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
def test_get_dataset_collection_binding_different_provider_model_combination(self, mock_db_session):
"""
Test retrieval with different provider/model combinations.
Verifies that bindings are correctly filtered by both provider_name and
model_name, ensuring that different model combinations have separate bindings.
This test ensures:
- Provider and model are both used as filters
- Different combinations result in different bindings
- The correct binding is returned for each combination
"""
# Arrange
provider_name = "huggingface"
model_name = "sentence-transformers/all-MiniLM-L6-v2"
collection_type = "dataset"
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
binding_id="binding-hf-123",
provider_name=provider_name,
model_name=model_name,
collection_type=collection_type,
)
# Mock the query chain
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = existing_binding
mock_db_session.query.return_value = mock_query
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding(
provider_name=provider_name, model_name=model_name, collection_type=collection_type
)
# Assert
assert result == existing_binding
assert result.provider_name == provider_name
assert result.model_name == model_name
# Verify query filters were applied correctly
# The query should filter by both provider_name and model_name
# This ensures different model combinations have separate bindings
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
# Verify the where clause was applied with all three filters:
# - provider_name filter
# - model_name filter
# - collection_type filter
mock_query.where.assert_called_once()
# ============================================================================
# Tests for get_dataset_collection_binding_by_id_and_type
# ============================================================================
# This section contains tests for the get_dataset_collection_binding_by_id_and_type
# method, which retrieves a specific collection binding by its ID and type.
#
# Key differences from get_dataset_collection_binding:
# 1. This method queries by ID and type, not by provider/model/type
# 2. This method does NOT create a new binding if one doesn't exist
# 3. This method raises ValueError if the binding is not found
# 4. This method is typically used when you already know the binding ID
#
# Use cases:
# - Retrieving a binding that was previously created
# - Validating that a binding exists before using it
# - Accessing binding metadata when you have the ID
#
# ============================================================================
class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
"""
Comprehensive unit tests for DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type method.
This test class covers collection binding retrieval by ID and type,
including success scenarios and error handling for missing bindings.
The get_dataset_collection_binding_by_id_and_type method:
1. Queries for a binding by collection_binding_id and collection_type
2. Orders results by created_at (ascending) and takes the first match
3. If no binding exists, raises ValueError("Dataset collection binding not found")
4. Returns the found binding
Unlike get_dataset_collection_binding, this method does NOT create a new
binding if one doesn't exist - it only retrieves existing bindings.
Test scenarios include:
- Successful retrieval of existing bindings
- Error handling for missing bindings
- Different collection types
- Default collection type behavior
"""
@pytest.fixture
def mock_db_session(self):
"""
Mock database session for testing database operations.
Provides a mocked database session that can be used to verify:
- Query construction with ID and type filters
- Ordering by created_at
- First result retrieval
The mock is configured to return a query builder that supports
chaining operations like .where(), .order_by(), and .first().
"""
with patch("services.dataset_service.db.session") as mock_db:
yield mock_db
def test_get_dataset_collection_binding_by_id_and_type_success(self, mock_db_session):
"""
Test successful retrieval of a collection binding by ID and type.
Verifies that when a binding exists in the database with the given
ID and collection type, the method returns the binding.
This test ensures:
- The query is constructed correctly with ID and type filters
- Results are ordered by created_at
- The first matching binding is returned
- No error is raised
"""
# Arrange
collection_binding_id = "binding-123"
collection_type = "dataset"
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
binding_id=collection_binding_id,
provider_name="openai",
model_name="text-embedding-ada-002",
collection_type=collection_type,
)
# Mock the query chain: query().where().order_by().first()
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = existing_binding
mock_db_session.query.return_value = mock_query
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id=collection_binding_id, collection_type=collection_type
)
# Assert
assert result == existing_binding
assert result.id == collection_binding_id
assert result.type == collection_type
# Verify query was constructed correctly
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
mock_query.where.assert_called_once()
mock_where.order_by.assert_called_once()
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, mock_db_session):
"""
Test error handling when binding is not found.
Verifies that when no binding exists in the database with the given
ID and collection type, the method raises a ValueError with the
message "Dataset collection binding not found".
This test ensures:
- The query returns None (no existing binding)
- ValueError is raised with the correct message
- No binding is returned
"""
# Arrange
collection_binding_id = "non-existent-binding"
collection_type = "dataset"
# Mock the query chain to return None (no existing binding)
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = None # No existing binding
mock_db_session.query.return_value = mock_query
# Act & Assert
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id=collection_binding_id, collection_type=collection_type
)
# Verify query was attempted
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
mock_query.where.assert_called_once()
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, mock_db_session):
"""
Test retrieval with a different collection type.
Verifies that the method correctly filters by collection_type, ensuring
that bindings with the same ID but different types are treated as
separate entities.
This test ensures:
- Collection type is properly used as a filter in the query
- Different collection types can have separate bindings with same ID
- The correct binding is returned based on type
"""
# Arrange
collection_binding_id = "binding-456"
collection_type = "custom_type"
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
binding_id=collection_binding_id,
provider_name="cohere",
model_name="embed-english-v3.0",
collection_type=collection_type,
)
# Mock the query chain
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = existing_binding
mock_db_session.query.return_value = mock_query
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id=collection_binding_id, collection_type=collection_type
)
# Assert
assert result == existing_binding
assert result.id == collection_binding_id
assert result.type == collection_type
# Verify query was constructed with the correct type filter
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
mock_query.where.assert_called_once()
def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, mock_db_session):
"""
Test retrieval with default collection type ("dataset").
Verifies that when collection_type is not provided, it defaults to "dataset"
as specified in the method signature.
This test ensures:
- The default value "dataset" is used when type is not specified
- The query correctly filters by the default type
- The correct binding is returned
"""
# Arrange
collection_binding_id = "binding-789"
# collection_type defaults to "dataset" in method signature
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
binding_id=collection_binding_id,
provider_name="openai",
model_name="text-embedding-ada-002",
collection_type="dataset", # Default type
)
# Mock the query chain
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = existing_binding
mock_db_session.query.return_value = mock_query
# Act - call without specifying collection_type (uses default)
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id=collection_binding_id
)
# Assert
assert result == existing_binding
assert result.id == collection_binding_id
assert result.type == "dataset"
# Verify query was constructed correctly
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
mock_query.where.assert_called_once()
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, mock_db_session):
"""
Test error handling when binding exists but with wrong collection type.
Verifies that when a binding exists with the given ID but a different
collection type, the method raises a ValueError because the binding
doesn't match both the ID and type criteria.
This test ensures:
- The query correctly filters by both ID and type
- Bindings with matching ID but different type are not returned
- ValueError is raised when no matching binding is found
"""
# Arrange
collection_binding_id = "binding-123"
collection_type = "dataset"
# Mock the query chain to return None (binding exists but with different type)
mock_query = Mock()
mock_where = Mock()
mock_order_by = Mock()
mock_query.where.return_value = mock_where
mock_where.order_by.return_value = mock_order_by
mock_order_by.first.return_value = None # No matching binding
mock_db_session.query.return_value = mock_query
# Act & Assert
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
collection_binding_id=collection_binding_id, collection_type=collection_type
)
# Verify query was attempted with both ID and type filters
# The query should filter by both collection_binding_id and collection_type
# This ensures we only get bindings that match both criteria
mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
# Verify the where clause was applied with both filters:
# - collection_binding_id filter (exact match)
# - collection_type filter (exact match)
mock_query.where.assert_called_once()
# Note: The order_by and first() calls are also part of the query chain,
# but we don't need to verify them separately since they're part of the
# standard query pattern used by both methods in this service.
# ============================================================================
# Additional Test Scenarios and Edge Cases
# ============================================================================
# The following section could contain additional test scenarios if needed:
#
# Potential additional tests:
# 1. Test with multiple existing bindings (verify ordering by created_at)
# 2. Test with very long provider/model names (boundary testing)
# 3. Test with special characters in provider/model names
# 4. Test concurrent binding creation (thread safety)
# 5. Test database rollback scenarios
# 6. Test with None values for optional parameters
# 7. Test with empty strings for required parameters
# 8. Test collection name generation uniqueness
# 9. Test with different UUID formats
# 10. Test query performance with large datasets
#
# These scenarios are not currently implemented but could be added if needed
# based on real-world usage patterns or discovered edge cases.
#
# ============================================================================
# ============================================================================
# Integration Notes and Best Practices
# ============================================================================
#
# When using DatasetCollectionBindingService in production code, consider:
#
# 1. Error Handling:
# - Always handle ValueError exceptions when calling
# get_dataset_collection_binding_by_id_and_type
# - Check return values from get_dataset_collection_binding to ensure
# bindings were created successfully
#
# 2. Performance Considerations:
# - The service queries the database on every call, so consider caching
# bindings if they're accessed frequently
# - Collection bindings are typically long-lived, so caching is safe
#
# 3. Transaction Management:
# - New bindings are automatically committed to the database
# - If you need to rollback, ensure you're within a transaction context
#
# 4. Collection Type Usage:
# - Use "dataset" for standard dataset collections
# - Use custom types only when you need to separate collections by purpose
# - Be consistent with collection type naming across your application
#
# 5. Provider and Model Naming:
# - Use consistent provider names (e.g., "openai", not "OpenAI" or "OPENAI")
# - Use exact model names as provided by the model provider
# - These names are case-sensitive and must match exactly
#
# ============================================================================
# ============================================================================
# Database Schema Reference
# ============================================================================
#
# The DatasetCollectionBinding model has the following structure:
#
# - id: StringUUID (primary key, auto-generated)
# - provider_name: String(255) (required, e.g., "openai", "cohere")
# - model_name: String(255) (required, e.g., "text-embedding-ada-002")
# - type: String(40) (required, default: "dataset")
# - collection_name: String(64) (required, unique collection identifier)
# - created_at: DateTime (auto-generated timestamp)
#
# Indexes:
# - Primary key on id
# - Composite index on (provider_name, model_name) for efficient lookups
#
# Relationships:
# - One binding can be referenced by multiple datasets
# - Datasets reference bindings via collection_binding_id
#
# ============================================================================
# ============================================================================
# Mocking Strategy Documentation
# ============================================================================
#
# This test suite uses extensive mocking to isolate the unit under test.
# Here's how the mocking strategy works:
#
# 1. Database Session Mocking:
# - db.session is patched to prevent actual database access
# - Query chains are mocked to return predictable results
# - Add and commit operations are tracked for verification
#
# 2. Query Chain Mocking:
# - query() returns a mock query object
# - where() returns a mock where object
# - order_by() returns a mock order_by object
# - first() returns the final result (binding or None)
#
# 3. UUID Generation Mocking:
# - uuid.uuid4() is mocked to return predictable UUIDs
# - This ensures collection names are generated consistently in tests
#
# 4. Collection Name Generation Mocking:
# - Dataset.gen_collection_name_by_id() is mocked
# - This allows us to verify the method is called correctly
# - We can control the generated collection name for testing
#
# Benefits of this approach:
# - Tests run quickly (no database I/O)
# - Tests are deterministic (no random UUIDs)
# - Tests are isolated (no side effects)
# - Tests are maintainable (clear mock setup)
#
# ============================================================================

View File

@ -0,0 +1,920 @@
"""
Extensive unit tests for ``ExternalDatasetService``.
This module focuses on the *external dataset service* surface area, which is responsible
for integrating with **external knowledge APIs** and wiring them into Dify datasets.
The goal of this test suite is twofold:
- Provide **highconfidence regression coverage** for all public helpers on
``ExternalDatasetService``.
- Serve as **executable documentation** for how external API integration is expected
to behave in different scenarios (happy paths, validation failures, and error codes).
The file intentionally contains **rich comments and generous spacing** in order to make
each scenario easy to scan during reviews.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch
import httpx
import pytest
from constants import HIDDEN_VALUE
from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
from services.entities.external_knowledge_entities.external_knowledge_entities import (
Authorization,
AuthorizationConfig,
ExternalKnowledgeApiSetting,
)
from services.errors.dataset import DatasetNameDuplicateError
from services.external_knowledge_service import ExternalDatasetService
class ExternalDatasetTestDataFactory:
"""
Factory helpers for building *lightweight* mocks for external knowledge tests.
These helpers are intentionally small and explicit:
- They avoid pulling in unnecessary fixtures.
- They reflect the minimal contract that the service under test cares about.
"""
@staticmethod
def create_external_api(
api_id: str = "api-123",
tenant_id: str = "tenant-1",
name: str = "Test API",
description: str = "Description",
settings: dict | None = None,
) -> ExternalKnowledgeApis:
"""
Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields.
Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to
exercise ``settings_dict`` and other convenience properties if needed.
"""
instance = ExternalKnowledgeApis(
tenant_id=tenant_id,
name=name,
description=description,
settings=None if settings is None else cast(str, pytest.approx), # type: ignore[assignment]
)
# Overwrite generated id for determinism in assertions.
instance.id = api_id
return instance
@staticmethod
def create_dataset(
dataset_id: str = "ds-1",
tenant_id: str = "tenant-1",
name: str = "External Dataset",
provider: str = "external",
) -> Dataset:
"""
Build a small ``Dataset`` instance representing an external dataset.
"""
dataset = Dataset(
tenant_id=tenant_id,
name=name,
description="",
provider=provider,
created_by="user-1",
)
dataset.id = dataset_id
return dataset
@staticmethod
def create_external_binding(
tenant_id: str = "tenant-1",
dataset_id: str = "ds-1",
api_id: str = "api-1",
external_knowledge_id: str = "knowledge-1",
) -> ExternalKnowledgeBindings:
"""
Small helper for a binding between dataset and external knowledge API.
"""
binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset_id,
external_knowledge_api_id=api_id,
external_knowledge_id=external_knowledge_id,
created_by="user-1",
)
return binding
# ---------------------------------------------------------------------------
# get_external_knowledge_apis
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceGetExternalKnowledgeApis:
"""
Tests for ``ExternalDatasetService.get_external_knowledge_apis``.
These tests focus on:
- Basic pagination wiring via ``db.paginate``.
- Optional search keyword behaviour.
"""
@pytest.fixture
def mock_db_paginate(self):
"""
Patch ``db.paginate`` so we do not touch the real database layer.
"""
with (
patch("services.external_knowledge_service.db.paginate") as mock_paginate,
patch("services.external_knowledge_service.select"),
):
yield mock_paginate
def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock):
"""
It should return ``items`` and ``total`` coming from the paginate object.
"""
# Arrange
tenant_id = "tenant-1"
page = 1
per_page = 20
mock_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)]
mock_pagination = SimpleNamespace(items=mock_items, total=42)
mock_db_paginate.return_value = mock_pagination
# Act
items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id)
# Assert
assert items is mock_items
assert total == 42
mock_db_paginate.assert_called_once()
call_kwargs = mock_db_paginate.call_args.kwargs
assert call_kwargs["page"] == page
assert call_kwargs["per_page"] == per_page
assert call_kwargs["max_per_page"] == 100
assert call_kwargs["error_out"] is False
def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock):
"""
When a search keyword is provided, the query should be adjusted
(we simply assert that paginate is still called and does not explode).
"""
# Arrange
tenant_id = "tenant-1"
page = 2
per_page = 10
search = "foo"
mock_pagination = SimpleNamespace(items=[], total=0)
mock_db_paginate.return_value = mock_pagination
# Act
items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id, search=search)
# Assert
assert items == []
assert total == 0
mock_db_paginate.assert_called_once()
# ---------------------------------------------------------------------------
# validate_api_list
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceValidateApiList:
"""
Lightweight validation tests for ``validate_api_list``.
"""
def test_validate_api_list_success(self):
"""
A minimal valid configuration (endpoint + api_key) should pass.
"""
config = {"endpoint": "https://example.com", "api_key": "secret"}
# Act & Assert no exception expected
ExternalDatasetService.validate_api_list(config)
@pytest.mark.parametrize(
("config", "expected_message"),
[
({}, "api list is empty"),
({"api_key": "k"}, "endpoint is required"),
({"endpoint": "https://example.com"}, "api_key is required"),
],
)
def test_validate_api_list_failures(self, config: dict, expected_message: str):
"""
Invalid configs should raise ``ValueError`` with a clear message.
"""
with pytest.raises(ValueError, match=expected_message):
ExternalDatasetService.validate_api_list(config)
# ---------------------------------------------------------------------------
# create_external_knowledge_api & get/update/delete
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
CRUD tests for external knowledge API templates.
"""
@pytest.fixture
def mock_db_session(self):
"""
Patch ``db.session`` for all CRUD tests in this class.
"""
with patch("services.external_knowledge_service.db.session") as mock_session:
yield mock_session
def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
"""
``create_external_knowledge_api`` should persist a new record
when settings are present and valid.
"""
tenant_id = "tenant-1"
user_id = "user-1"
args = {
"name": "API",
"description": "desc",
"settings": {"endpoint": "https://api.example.com", "api_key": "secret"},
}
# We do not want to actually call the remote endpoint here, so we patch the validator.
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check:
result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
assert isinstance(result, ExternalKnowledgeApis)
mock_check.assert_called_once_with(args["settings"])
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock):
"""
Missing ``settings`` should result in a ``ValueError``.
"""
tenant_id = "tenant-1"
user_id = "user-1"
args = {"name": "API", "description": "desc"}
with pytest.raises(ValueError, match="settings is required"):
ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
mock_db_session.add.assert_not_called()
mock_db_session.commit.assert_not_called()
def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock):
"""
``get_external_knowledge_api`` should return the first matching record.
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
result = ExternalDatasetService.get_external_knowledge_api("api-id")
assert result is api
def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
"""
When the record is absent, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.get_external_knowledge_api("missing-id")
def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock):
"""
Updating an API should keep the existing API key when the special hidden
value placeholder is sent from the UI.
"""
tenant_id = "tenant-1"
user_id = "user-1"
api_id = "api-1"
existing_api = Mock(spec=ExternalKnowledgeApis)
existing_api.settings_dict = {"api_key": "stored-key"}
existing_api.settings = '{"api_key":"stored-key"}'
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
args = {
"name": "New Name",
"description": "New Desc",
"settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
}
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
assert result is existing_api
# The placeholder should be replaced with stored key.
assert args["settings"]["api_key"] == "stored-key"
mock_db_session.commit.assert_called_once()
def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
"""
Updating a nonexistent API template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.update_external_knowledge_api(
tenant_id="tenant-1",
user_id="user-1",
external_knowledge_api_id="missing-id",
args={"name": "n", "description": "d", "settings": {}},
)
def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock):
"""
``delete_external_knowledge_api`` should delete and commit when found.
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
mock_db_session.delete.assert_called_once_with(api)
mock_db_session.commit.assert_called_once()
def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
"""
Deletion of a missing template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
# ---------------------------------------------------------------------------
# external_knowledge_api_use_check & binding lookups
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceUsageAndBindings:
"""
Tests for usage checks and dataset binding retrieval.
"""
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
yield mock_session
def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
"""
When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
assert in_use is True
assert count == 3
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
"""
Zero bindings should return ``(False, 0)``.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
assert in_use is False
assert count == 0
def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock):
"""
Binding lookup should return the first record when present.
"""
binding = Mock(spec=ExternalKnowledgeBindings)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
assert result is binding
def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock):
"""
Missing binding should result in a ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
# ---------------------------------------------------------------------------
# document_create_args_validate
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceDocumentCreateArgsValidate:
"""
Tests for ``document_create_args_validate``.
"""
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
yield mock_session
def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
"""
All required custom parameters present validation should pass.
"""
external_api = Mock(spec=ExternalKnowledgeApis)
external_api.settings = json_settings = (
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
# Raw string; the service itself calls json.loads on it
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
process_parameter = {"foo": "value", "bar": "optional"}
# Act & Assert no exception
ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
assert json_settings in external_api.settings # simple sanity check on our test data
def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock):
"""
When the referenced API template is missing, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock):
"""
Required document process parameters must be supplied.
"""
external_api = Mock(spec=ExternalKnowledgeApis)
external_api.settings = (
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
process_parameter = {"bar": "present"} # missing "foo"
with pytest.raises(ValueError, match="foo is required"):
ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
# ---------------------------------------------------------------------------
# process_external_api
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceProcessExternalApi:
"""
Tests focused on the HTTP request assembly and method mapping behaviour.
"""
def test_process_external_api_valid_method_post(self):
"""
For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function.
"""
settings = ExternalKnowledgeApiSetting(
url="https://example.com/path",
request_method="POST",
headers={"X-Test": "1"},
params={"foo": "bar"},
)
fake_response = httpx.Response(200)
with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post:
mock_post.return_value = fake_response
result = ExternalDatasetService.process_external_api(settings, files=None)
assert result is fake_response
mock_post.assert_called_once()
kwargs = mock_post.call_args.kwargs
assert kwargs["url"] == settings.url
assert kwargs["headers"] == settings.headers
assert kwargs["follow_redirects"] is True
assert "data" in kwargs
def test_process_external_api_invalid_method_raises(self):
"""
An unsupported HTTP verb should raise ``InvalidHttpMethodError``.
"""
settings = ExternalKnowledgeApiSetting(
url="https://example.com",
request_method="INVALID",
headers=None,
params={},
)
from core.workflow.nodes.http_request.exc import InvalidHttpMethodError
with pytest.raises(InvalidHttpMethodError):
ExternalDatasetService.process_external_api(settings, files=None)
# ---------------------------------------------------------------------------
# assembling_headers
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceAssemblingHeaders:
"""
Tests for header assembly based on different authentication flavours.
"""
def test_assembling_headers_bearer_token(self):
"""
For bearer auth we expect ``Authorization: Bearer <key>`` by default.
"""
auth = Authorization(
type="api-key",
config=AuthorizationConfig(type="bearer", api_key="secret", header=None),
)
headers = ExternalDatasetService.assembling_headers(auth)
assert headers["Authorization"] == "Bearer secret"
def test_assembling_headers_basic_token_with_custom_header(self):
"""
For basic auth we honour the configured header name.
"""
auth = Authorization(
type="api-key",
config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"),
)
headers = ExternalDatasetService.assembling_headers(auth, headers={"Existing": "1"})
assert headers["Existing"] == "1"
assert headers["X-Auth"] == "Basic abc123"
def test_assembling_headers_custom_type(self):
"""
Custom auth type should inject the raw API key.
"""
auth = Authorization(
type="api-key",
config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"),
)
headers = ExternalDatasetService.assembling_headers(auth, headers=None)
assert headers["X-API-KEY"] == "raw-key"
def test_assembling_headers_missing_config_raises(self):
"""
Missing config object should be rejected.
"""
auth = Authorization(type="api-key", config=None)
with pytest.raises(ValueError, match="authorization config is required"):
ExternalDatasetService.assembling_headers(auth)
def test_assembling_headers_missing_api_key_raises(self):
"""
``api_key`` is required when type is ``api-key``.
"""
auth = Authorization(
type="api-key",
config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"),
)
with pytest.raises(ValueError, match="api_key is required"):
ExternalDatasetService.assembling_headers(auth)
def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self):
"""
For ``no-auth`` we should not modify the headers mapping.
"""
auth = Authorization(type="no-auth", config=None)
base_headers = {"X": "1"}
result = ExternalDatasetService.assembling_headers(auth, headers=base_headers)
# A copy is returned, original is not mutated.
assert result == base_headers
assert result is not base_headers
# ---------------------------------------------------------------------------
# get_external_knowledge_api_settings
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceGetExternalKnowledgeApiSettings:
"""
Simple shape test for ``get_external_knowledge_api_settings``.
"""
def test_get_external_knowledge_api_settings(self):
settings_dict: dict[str, Any] = {
"url": "https://example.com/retrieval",
"request_method": "post",
"headers": {"Content-Type": "application/json"},
"params": {"foo": "bar"},
}
result = ExternalDatasetService.get_external_knowledge_api_settings(settings_dict)
assert isinstance(result, ExternalKnowledgeApiSetting)
assert result.url == settings_dict["url"]
assert result.request_method == settings_dict["request_method"]
assert result.headers == settings_dict["headers"]
assert result.params == settings_dict["params"]
# ---------------------------------------------------------------------------
# create_external_dataset
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceCreateExternalDataset:
"""
Tests around creating the external dataset and its binding row.
"""
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
yield mock_session
def test_create_external_dataset_success(self, mock_db_session: MagicMock):
"""
A brand new dataset name with valid external knowledge references
should create both the dataset and its binding.
"""
tenant_id = "tenant-1"
user_id = "user-1"
args = {
"name": "My Dataset",
"description": "desc",
"external_knowledge_api_id": "api-1",
"external_knowledge_id": "knowledge-1",
"external_retrieval_model": {"top_k": 3},
}
# No existing dataset with same name.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
None, # duplicatename check
Mock(spec=ExternalKnowledgeApis), # external knowledge api
]
dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
assert isinstance(dataset, Dataset)
assert dataset.provider == "external"
assert dataset.retrieval_model == args["external_retrieval_model"]
assert mock_db_session.add.call_count >= 2 # dataset + binding
mock_db_session.flush.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock):
"""
When a dataset with the same name already exists,
``DatasetNameDuplicateError`` is raised.
"""
existing_dataset = Mock(spec=Dataset)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
args = {
"name": "Existing",
"external_knowledge_api_id": "api-1",
"external_knowledge_id": "knowledge-1",
}
with pytest.raises(DatasetNameDuplicateError):
ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
mock_db_session.add.assert_not_called()
mock_db_session.commit.assert_not_called()
def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock):
"""
If the referenced external knowledge API does not exist, a ``ValueError`` is raised.
"""
# First call: duplicate name check not found.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
None,
None, # external knowledge api lookup
]
args = {
"name": "Dataset",
"external_knowledge_api_id": "missing",
"external_knowledge_id": "knowledge-1",
}
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock):
"""
``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
"""
# duplicate name check
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
None,
Mock(spec=ExternalKnowledgeApis),
]
args_missing_knowledge_id = {
"name": "Dataset",
"external_knowledge_api_id": "api-1",
"external_knowledge_id": None,
}
with pytest.raises(ValueError, match="external_knowledge_id is required"):
ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id)
args_missing_api_id = {
"name": "Dataset",
"external_knowledge_api_id": None,
"external_knowledge_id": "k-1",
}
with pytest.raises(ValueError, match="external_knowledge_api_id is required"):
ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id)
# ---------------------------------------------------------------------------
# fetch_external_knowledge_retrieval
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
"""
Tests for ``fetch_external_knowledge_retrieval`` which orchestrates
external retrieval requests and normalises the response payload.
"""
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
yield mock_session
def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
"""
With a valid binding and API template, records from the external
service should be returned when the HTTP response is 200.
"""
tenant_id = "tenant-1"
dataset_id = "ds-1"
query = "test query"
external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5}
binding = ExternalDatasetTestDataFactory.create_external_binding(
tenant_id=tenant_id,
dataset_id=dataset_id,
api_id="api-1",
external_knowledge_id="knowledge-1",
)
api = Mock(spec=ExternalKnowledgeApis)
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
# First query: binding; second query: api.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
binding,
api,
]
fake_records = [{"content": "doc", "score": 0.9}]
fake_response = Mock(spec=httpx.Response)
fake_response.status_code = 200
fake_response.json.return_value = {"records": fake_records}
metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process:
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=tenant_id,
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=external_retrieval_parameters,
metadata_condition=metadata_condition,
)
assert result == fake_records
mock_process.assert_called_once()
setting_arg = mock_process.call_args.args[0]
assert isinstance(setting_arg, ExternalKnowledgeApiSetting)
assert setting_arg.url.endswith("/retrieval")
def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock):
"""
Missing binding should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id="tenant-1",
dataset_id="missing",
query="q",
external_retrieval_parameters={},
metadata_condition=None,
)
def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock):
"""
When the API template is missing or has no settings, a ``ValueError`` is raised.
"""
binding = ExternalDatasetTestDataFactory.create_external_binding()
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
binding,
None,
]
with pytest.raises(ValueError, match="external api template not found"):
ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id="tenant-1",
dataset_id="ds-1",
query="q",
external_retrieval_parameters={},
metadata_condition=None,
)
def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock):
"""
Non200 responses should be treated as an empty result set.
"""
binding = ExternalDatasetTestDataFactory.create_external_binding()
api = Mock(spec=ExternalKnowledgeApis)
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
binding,
api,
]
fake_response = Mock(spec=httpx.Response)
fake_response.status_code = 500
fake_response.json.return_value = {}
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response):
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id="tenant-1",
dataset_id="ds-1",
query="q",
external_retrieval_parameters={},
metadata_condition=None,
)
assert result == []

View File

@ -0,0 +1,802 @@
"""
Unit tests for HitTestingService.
This module contains comprehensive unit tests for the HitTestingService class,
which handles retrieval testing operations for datasets, including internal
dataset retrieval and external knowledge base retrieval.
"""
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from models import Account
from models.dataset import Dataset
from services.hit_testing_service import HitTestingService
class HitTestingTestDataFactory:
"""
Factory class for creating test data and mock objects for hit testing service tests.
This factory provides static methods to create mock objects for datasets, users,
documents, and retrieval records used in HitTestingService unit tests.
"""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
provider: str = "vendor",
retrieval_model: dict | None = None,
**kwargs,
) -> Mock:
"""
Create a mock dataset with specified attributes.
Args:
dataset_id: Unique identifier for the dataset
tenant_id: Tenant identifier
provider: Dataset provider (vendor, external, etc.)
retrieval_model: Optional retrieval model configuration
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as a Dataset instance
"""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.provider = provider
dataset.retrieval_model = retrieval_model
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(
user_id: str = "user-789",
tenant_id: str = "tenant-123",
**kwargs,
) -> Mock:
"""
Create a mock user (Account) with specified attributes.
Args:
user_id: Unique identifier for the user
tenant_id: Tenant identifier
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as an Account instance
"""
user = Mock(spec=Account)
user.id = user_id
user.current_tenant_id = tenant_id
user.name = "Test User"
for key, value in kwargs.items():
setattr(user, key, value)
return user
@staticmethod
def create_document_mock(
content: str = "Test document content",
metadata: dict | None = None,
**kwargs,
) -> Mock:
"""
Create a mock Document from core.rag.models.document.
Args:
content: Document content/text
metadata: Optional metadata dictionary
**kwargs: Additional attributes to set on the mock
Returns:
Mock object configured as a Document instance
"""
document = Mock(spec=Document)
document.page_content = content
document.metadata = metadata or {}
for key, value in kwargs.items():
setattr(document, key, value)
return document
@staticmethod
def create_retrieval_record_mock(
content: str = "Test content",
score: float = 0.95,
**kwargs,
) -> Mock:
"""
Create a mock retrieval record.
Args:
content: Record content
score: Retrieval score
**kwargs: Additional fields for the record
Returns:
Mock object with model_dump method returning record data
"""
record = Mock()
record.model_dump.return_value = {
"content": content,
"score": score,
**kwargs,
}
return record
class TestHitTestingServiceRetrieve:
"""
Tests for HitTestingService.retrieve method (hit_testing).
This test class covers the main retrieval testing functionality, including
various retrieval model configurations, metadata filtering, and query logging.
"""
@pytest.fixture
def mock_db_session(self):
"""
Mock database session.
Provides a mocked database session for testing database operations
like adding and committing DatasetQuery records.
"""
with patch("services.hit_testing_service.db.session") as mock_db:
yield mock_db
def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
"""
Test successful retrieval with default retrieval model.
Verifies that the retrieve method works correctly when no custom
retrieval model is provided, using the default retrieval configuration.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=None)
account = HitTestingTestDataFactory.create_user_mock()
query = "test query"
retrieval_model = None
external_retrieval_model = {}
documents = [
HitTestingTestDataFactory.create_document_mock(content="Doc 1"),
HitTestingTestDataFactory.create_document_mock(content="Doc 2"),
]
mock_records = [
HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 1"),
HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2"),
]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1] # start, end
mock_retrieve.return_value = documents
mock_format.return_value = mock_records
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
# Assert
assert result["query"]["content"] == query
assert len(result["records"]) == 2
mock_retrieve.assert_called_once()
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_retrieve_success_with_custom_retrieval_model(self, mock_db_session):
"""
Test successful retrieval with custom retrieval model.
Verifies that custom retrieval model parameters (search method, reranking,
score threshold, etc.) are properly passed to RetrievalService.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock()
account = HitTestingTestDataFactory.create_user_mock()
query = "test query"
retrieval_model = {
"search_method": RetrievalMethod.KEYWORD_SEARCH,
"reranking_enable": True,
"reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-1"},
"top_k": 5,
"score_threshold_enabled": True,
"score_threshold": 0.7,
"weights": {"vector_setting": 0.5, "keyword_setting": 0.5},
}
external_retrieval_model = {}
documents = [HitTestingTestDataFactory.create_document_mock()]
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_retrieve.return_value = documents
mock_format.return_value = mock_records
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
# Assert
assert result["query"]["content"] == query
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args[1]
assert call_kwargs["retrieval_method"] == RetrievalMethod.KEYWORD_SEARCH
assert call_kwargs["top_k"] == 5
assert call_kwargs["score_threshold"] == 0.7
assert call_kwargs["reranking_model"] == retrieval_model["reranking_model"]
def test_retrieve_with_metadata_filtering(self, mock_db_session):
"""
Test retrieval with metadata filtering conditions.
Verifies that metadata filtering conditions are properly processed
and document ID filters are applied to the retrieval query.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock()
account = HitTestingTestDataFactory.create_user_mock()
query = "test query"
retrieval_model = {
"metadata_filtering_conditions": {
"conditions": [
{"field": "category", "operator": "is", "value": "test"},
],
},
}
external_retrieval_model = {}
mock_dataset_retrieval = MagicMock()
mock_dataset_retrieval.get_metadata_filter_condition.return_value = (
{dataset.id: ["doc-1", "doc-2"]},
None,
)
documents = [HitTestingTestDataFactory.create_document_mock()]
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
mock_retrieve.return_value = documents
mock_format.return_value = mock_records
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
# Assert
assert result["query"]["content"] == query
mock_dataset_retrieval.get_metadata_filter_condition.assert_called_once()
call_kwargs = mock_retrieve.call_args[1]
assert call_kwargs["document_ids_filter"] == ["doc-1", "doc-2"]
def test_retrieve_with_metadata_filtering_no_documents(self, mock_db_session):
"""
Test retrieval with metadata filtering that returns no documents.
Verifies that when metadata filtering results in no matching documents,
an empty result is returned without calling RetrievalService.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock()
account = HitTestingTestDataFactory.create_user_mock()
query = "test query"
retrieval_model = {
"metadata_filtering_conditions": {
"conditions": [
{"field": "category", "operator": "is", "value": "test"},
],
},
}
external_retrieval_model = {}
mock_dataset_retrieval = MagicMock()
mock_dataset_retrieval.get_metadata_filter_condition.return_value = ({}, True)
with (
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
):
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
mock_format.return_value = []
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
# Assert
assert result["query"]["content"] == query
assert result["records"] == []
def test_retrieve_with_dataset_retrieval_model(self, mock_db_session):
"""
Test retrieval using dataset's retrieval model when not provided.
Verifies that when no retrieval model is provided, the dataset's
retrieval model is used as a fallback.
"""
# Arrange
dataset_retrieval_model = {
"search_method": RetrievalMethod.HYBRID_SEARCH,
"top_k": 3,
}
dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=dataset_retrieval_model)
account = HitTestingTestDataFactory.create_user_mock()
query = "test query"
retrieval_model = None
external_retrieval_model = {}
documents = [HitTestingTestDataFactory.create_document_mock()]
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_retrieve.return_value = documents
mock_format.return_value = mock_records
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
# Assert
assert result["query"]["content"] == query
call_kwargs = mock_retrieve.call_args[1]
assert call_kwargs["retrieval_method"] == RetrievalMethod.HYBRID_SEARCH
assert call_kwargs["top_k"] == 3
class TestHitTestingServiceExternalRetrieve:
"""
Tests for HitTestingService.external_retrieve method.
This test class covers external knowledge base retrieval functionality,
including query escaping, response formatting, and provider validation.
"""
@pytest.fixture
def mock_db_session(self):
"""
Mock database session.
Provides a mocked database session for testing database operations
like adding and committing DatasetQuery records.
"""
with patch("services.hit_testing_service.db.session") as mock_db:
yield mock_db
def test_external_retrieve_success(self, mock_db_session):
"""
Test successful external retrieval.
Verifies that external knowledge base retrieval works correctly,
including query escaping, document formatting, and query logging.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
account = HitTestingTestDataFactory.create_user_mock()
query = 'test query with "quotes"'
external_retrieval_model = {"top_k": 5, "score_threshold": 0.8}
metadata_filtering_conditions = {}
external_documents = [
{"content": "External doc 1", "title": "Title 1", "score": 0.95, "metadata": {"key": "value"}},
{"content": "External doc 2", "title": "Title 2", "score": 0.85, "metadata": {}},
]
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = external_documents
# Act
result = HitTestingService.external_retrieve(
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
)
# Assert
assert result["query"]["content"] == query
assert len(result["records"]) == 2
assert result["records"][0]["content"] == "External doc 1"
assert result["records"][0]["title"] == "Title 1"
assert result["records"][0]["score"] == 0.95
mock_external_retrieve.assert_called_once()
# Verify query was escaped
assert mock_external_retrieve.call_args[1]["query"] == 'test query with \\"quotes\\"'
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_external_retrieve_non_external_provider(self, mock_db_session):
"""
Test external retrieval with non-external provider (should return empty).
Verifies that when the dataset provider is not "external", the method
returns an empty result without performing retrieval or database operations.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="vendor")
account = HitTestingTestDataFactory.create_user_mock()
query = "test query"
external_retrieval_model = {}
metadata_filtering_conditions = {}
# Act
result = HitTestingService.external_retrieve(
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
)
# Assert
assert result["query"]["content"] == query
assert result["records"] == []
mock_db_session.add.assert_not_called()
def test_external_retrieve_with_metadata_filtering(self, mock_db_session):
"""
Test external retrieval with metadata filtering conditions.
Verifies that metadata filtering conditions are properly passed
to the external retrieval service.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
account = HitTestingTestDataFactory.create_user_mock()
query = "test query"
external_retrieval_model = {"top_k": 3}
metadata_filtering_conditions = {"category": "test"}
external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = external_documents
# Act
result = HitTestingService.external_retrieve(
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
)
# Assert
assert result["query"]["content"] == query
assert len(result["records"]) == 1
call_kwargs = mock_external_retrieve.call_args[1]
assert call_kwargs["metadata_filtering_conditions"] == metadata_filtering_conditions
def test_external_retrieve_empty_documents(self, mock_db_session):
"""
Test external retrieval with empty document list.
Verifies that when external retrieval returns no documents,
an empty result is properly formatted and returned.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
account = HitTestingTestDataFactory.create_user_mock()
query = "test query"
external_retrieval_model = {}
metadata_filtering_conditions = {}
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = []
# Act
result = HitTestingService.external_retrieve(
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
)
# Assert
assert result["query"]["content"] == query
assert result["records"] == []
class TestHitTestingServiceCompactRetrieveResponse:
"""
Tests for HitTestingService.compact_retrieve_response method.
This test class covers response formatting for internal dataset retrieval,
ensuring documents are properly formatted into retrieval records.
"""
def test_compact_retrieve_response_success(self):
"""
Test successful response formatting.
Verifies that documents are properly formatted into retrieval records
with correct structure and data.
"""
# Arrange
query = "test query"
documents = [
HitTestingTestDataFactory.create_document_mock(content="Doc 1"),
HitTestingTestDataFactory.create_document_mock(content="Doc 2"),
]
mock_records = [
HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 1", score=0.95),
HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2", score=0.85),
]
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
mock_format.return_value = mock_records
# Act
result = HitTestingService.compact_retrieve_response(query, documents)
# Assert
assert result["query"]["content"] == query
assert len(result["records"]) == 2
assert result["records"][0]["content"] == "Doc 1"
assert result["records"][0]["score"] == 0.95
mock_format.assert_called_once_with(documents)
def test_compact_retrieve_response_empty_documents(self):
"""
Test response formatting with empty document list.
Verifies that an empty document list results in an empty records array
while maintaining the correct response structure.
"""
# Arrange
query = "test query"
documents = []
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
mock_format.return_value = []
# Act
result = HitTestingService.compact_retrieve_response(query, documents)
# Assert
assert result["query"]["content"] == query
assert result["records"] == []
class TestHitTestingServiceCompactExternalRetrieveResponse:
"""
Tests for HitTestingService.compact_external_retrieve_response method.
This test class covers response formatting for external knowledge base
retrieval, ensuring proper field extraction and provider validation.
"""
def test_compact_external_retrieve_response_external_provider(self):
"""
Test external response formatting for external provider.
Verifies that external documents are properly formatted with all
required fields (content, title, score, metadata).
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
query = "test query"
documents = [
{"content": "Doc 1", "title": "Title 1", "score": 0.95, "metadata": {"key": "value"}},
{"content": "Doc 2", "title": "Title 2", "score": 0.85, "metadata": {}},
]
# Act
result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)
# Assert
assert result["query"]["content"] == query
assert len(result["records"]) == 2
assert result["records"][0]["content"] == "Doc 1"
assert result["records"][0]["title"] == "Title 1"
assert result["records"][0]["score"] == 0.95
assert result["records"][0]["metadata"] == {"key": "value"}
def test_compact_external_retrieve_response_non_external_provider(self):
"""
Test external response formatting for non-external provider.
Verifies that non-external providers return an empty records array
regardless of input documents.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="vendor")
query = "test query"
documents = [{"content": "Doc 1"}]
# Act
result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)
# Assert
assert result["query"]["content"] == query
assert result["records"] == []
def test_compact_external_retrieve_response_missing_fields(self):
"""
Test external response formatting with missing optional fields.
Verifies that missing optional fields (title, score, metadata) are
handled gracefully by setting them to None.
"""
# Arrange
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
query = "test query"
documents = [
{"content": "Doc 1"}, # Missing title, score, metadata
{"content": "Doc 2", "title": "Title 2"}, # Missing score, metadata
]
# Act
result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)
# Assert
assert result["query"]["content"] == query
assert len(result["records"]) == 2
assert result["records"][0]["content"] == "Doc 1"
assert result["records"][0]["title"] is None
assert result["records"][0]["score"] is None
assert result["records"][0]["metadata"] is None
class TestHitTestingServiceHitTestingArgsCheck:
"""
Tests for HitTestingService.hit_testing_args_check method.
This test class covers query argument validation, ensuring queries
meet the required criteria (non-empty, max 250 characters).
"""
def test_hit_testing_args_check_success(self):
"""
Test successful argument validation.
Verifies that valid queries pass validation without raising errors.
"""
# Arrange
args = {"query": "valid query"}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_empty_query(self):
"""
Test validation fails with empty query.
Verifies that empty queries raise a ValueError with appropriate message.
"""
# Arrange
args = {"query": ""}
# Act & Assert
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_none_query(self):
"""
Test validation fails with None query.
Verifies that None queries raise a ValueError with appropriate message.
"""
# Arrange
args = {"query": None}
# Act & Assert
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_too_long_query(self):
"""
Test validation fails with query exceeding 250 characters.
Verifies that queries longer than 250 characters raise a ValueError.
"""
# Arrange
args = {"query": "a" * 251}
# Act & Assert
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_exactly_250_characters(self):
"""
Test validation succeeds with exactly 250 characters.
Verifies that queries with exactly 250 characters (the maximum)
pass validation successfully.
"""
# Arrange
args = {"query": "a" * 250}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
class TestHitTestingServiceEscapeQueryForSearch:
"""
Tests for HitTestingService.escape_query_for_search method.
This test class covers query escaping functionality for external search,
ensuring special characters are properly escaped.
"""
def test_escape_query_for_search_with_quotes(self):
"""
Test escaping quotes in query.
Verifies that double quotes in queries are properly escaped with
backslashes for external search compatibility.
"""
# Arrange
query = 'test query with "quotes"'
# Act
result = HitTestingService.escape_query_for_search(query)
# Assert
assert result == 'test query with \\"quotes\\"'
def test_escape_query_for_search_without_quotes(self):
"""
Test query without quotes (no change).
Verifies that queries without quotes remain unchanged after escaping.
"""
# Arrange
query = "test query without quotes"
# Act
result = HitTestingService.escape_query_for_search(query)
# Assert
assert result == query
def test_escape_query_for_search_multiple_quotes(self):
"""
Test escaping multiple quotes in query.
Verifies that all occurrences of double quotes in a query are
properly escaped, not just the first one.
"""
# Arrange
query = 'test "query" with "multiple" quotes'
# Act
result = HitTestingService.escape_query_for_search(query)
# Assert
assert result == 'test \\"query\\" with \\"multiple\\" quotes'
def test_escape_query_for_search_empty_string(self):
"""
Test escaping empty string.
Verifies that empty strings are handled correctly and remain empty
after the escaping operation.
"""
# Arrange
query = ""
# Act
result = HitTestingService.escape_query_for_search(query)
# Assert
assert result == ""

Some files were not shown because too many files have changed in this diff Show More