mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat-agent-mask
commit c0916e6eb2
@@ -0,0 +1,6 @@
# Cursor Rules for Dify Project

## Automated Test Generation

- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
- When proposing or saving tests, re-read that document and follow every requirement.
@@ -0,0 +1,12 @@
# Copilot Instructions

GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.

Key reminders:

- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, mock cleanup).
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
- Target >95% line and branch coverage and 100% function/statement coverage.
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.

Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
@@ -0,0 +1,5 @@
# Windsurf Testing Rules

- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
- Honor every requirement in that document when generating or accepting tests.
- When proposing or saving tests, re-read that document and follow every requirement.
@@ -77,6 +77,8 @@ How we prioritize:

For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.

**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.

#### Backend

For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly.
@@ -36,6 +36,12 @@
    <img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
  <a href="https://github.com/langgenius/dify/discussions/" target="_blank">
    <img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
  <a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
    <img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
  <a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
    <img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
  <a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
    <img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
</p>

<p align="center">
@ -176,6 +176,7 @@ WEAVIATE_ENDPOINT=http://localhost:8080
|
|||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
|
|
|
|||
|
|
@@ -16,6 +16,7 @@ layers =
    graph
    nodes
    node_events
    runtime
    entities
containers =
    core.workflow
@@ -57,7 +57,7 @@ RUN \
    # for gmpy2 \
    libgmp-dev libmpfr-dev libmpc-dev \
    # For Security
    expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
    expat libldap-2.5-0=2.5.13+dfsg-5 perl libsqlite3-0=3.40.1-2+deb12u2 zlib1g=1:1.2.13.dfsg-1 \
    # install fonts to support the use of tools like pypdfium2
    fonts-noto-cjk \
    # install a package to improve the accuracy of guessing mime type and file extension
@@ -73,7 +73,8 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"

# Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
    && chmod -R 755 /usr/local/share/nltk_data

ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
@@ -86,7 +87,15 @@ COPY . /app/api/
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

# Create non-root user and set permissions
RUN groupadd -r -g 1001 dify && \
    useradd -r -u 1001 -g 1001 -s /bin/bash dify && \
    mkdir -p /home/dify && \
    chown -R 1001:1001 /app /home/dify ${TIKTOKEN_CACHE_DIR} /entrypoint.sh

ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
ENV NLTK_DATA=/usr/local/share/nltk_data
USER 1001

ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
@@ -31,3 +31,8 @@ class WeaviateConfig(BaseSettings):
        description="Number of objects to be processed in a single batch operation (default is 100)",
        default=100,
    )

    WEAVIATE_TOKENIZATION: str | None = Field(
        description="Tokenization for Weaviate (default is word)",
        default="word",
    )
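The new `WEAVIATE_TOKENIZATION` field follows the pydantic-settings pattern used throughout Dify's config classes: the attribute name doubles as the environment variable (set above in the docker `.env` example), and `Field(default=...)` supplies the fallback. A minimal sketch of that mechanism, standalone rather than Dify's actual module layout:

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class WeaviateConfig(BaseSettings):
    # The attribute name is matched (case-insensitively) against env vars.
    WEAVIATE_TOKENIZATION: str | None = Field(
        description="Tokenization for Weaviate (default is word)",
        default="word",
    )


print(WeaviateConfig().WEAVIATE_TOKENIZATION)  # -> "word" (the default)

os.environ["WEAVIATE_TOKENIZATION"] = "gse"  # e.g. supplied via docker .env
print(WeaviateConfig().WEAVIATE_TOKENIZATION)  # -> "gse"
```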
@@ -17,7 +17,6 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
    ModelCurrentlyNotSupportError,
@@ -32,6 +31,7 @@ from libs.login import current_user, login_required
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError

logger = logging.getLogger(__name__)
@@ -121,7 +121,13 @@ class CompletionMessageStopApi(Resource):
    def post(self, app_model, task_id):
        if not isinstance(current_user, Account):
            raise ValueError("current_user must be an Account instance")
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)

        AppTaskService.stop_task(
            task_id=task_id,
            invoke_from=InvokeFrom.DEBUGGER,
            user_id=current_user.id,
            app_mode=AppMode.value_of(app_model.mode),
        )

        return {"result": "success"}, 200
@@ -220,6 +226,12 @@ class ChatMessageStopApi(Resource):
    def post(self, app_model, task_id):
        if not isinstance(current_user, Account):
            raise ValueError("current_user must be an Account instance")
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)

        AppTaskService.stop_task(
            task_id=task_id,
            invoke_from=InvokeFrom.DEBUGGER,
            user_id=current_user.id,
            app_mode=AppMode.value_of(app_model.mode),
        )

        return {"result": "success"}, 200
@@ -369,6 +369,58 @@ class MessageSuggestedQuestionApi(Resource):
        return {"data": questions}


# Shared parser for feedback export (used for both documentation and runtime parsing)
feedback_export_parser = (
    console_ns.parser()
    .add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
    .add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
    .add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
    .add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
    .add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
    .add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
)


@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
    @console_ns.doc("export_feedbacks")
    @console_ns.doc(description="Export user feedback data for Google Sheets")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(feedback_export_parser)
    @console_ns.response(200, "Feedback data exported successfully")
    @console_ns.response(400, "Invalid parameters")
    @console_ns.response(500, "Internal server error")
    @get_app_model
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, app_model):
        args = feedback_export_parser.parse_args()

        # Import the service function
        from services.feedback_service import FeedbackService

        try:
            export_data = FeedbackService.export_feedbacks(
                app_id=app_model.id,
                from_source=args.get("from_source"),
                rating=args.get("rating"),
                has_comment=args.get("has_comment"),
                start_date=args.get("start_date"),
                end_date=args.get("end_date"),
                format_type=args.get("format", "csv"),
            )

            return export_data

        except ValueError as e:
            logger.exception("Parameter validation error in feedback export")
            return {"error": f"Parameter validation error: {str(e)}"}, 400
        except Exception as e:
            logger.exception("Error exporting feedback data")
            raise InternalServerError(str(e))
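For context, the new export endpoint takes all of its filters as query parameters, so a client call might look like the sketch below. The base URL, app id, and the Authorization header are placeholder assumptions here, not the console API's exact auth scheme:

```python
import requests

# Hypothetical client call against the new export route.
resp = requests.get(
    "https://console.example.com/console/api/apps/<app_id>/feedbacks/export",
    params={"rating": "dislike", "has_comment": True, "format": "csv"},
    headers={"Authorization": "Bearer <console-session-token>"},  # assumed auth
    timeout=30,
)
resp.raise_for_status()
print(resp.text)  # CSV rows ready to paste into Google Sheets
```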
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
class MessageApi(Resource):
    @console_ns.doc("get_message")
@@ -90,14 +90,20 @@ workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields)
# Otherwise register it here
from fields.end_user_fields import simple_end_user_fields

simple_end_user_model = None
try:
    simple_end_user_model = console_ns.models.get("SimpleEndUser")
except (KeyError, AttributeError):
except AttributeError:
    pass
if simple_end_user_model is None:
    simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)

workflow_run_node_execution_model = None
try:
    workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
except (KeyError, AttributeError):
except AttributeError:
    pass
if workflow_run_node_execution_model is None:
    workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
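Narrowing `except (KeyError, AttributeError)` to `except AttributeError` is consistent with how `dict.get` behaves: a missing key returns `None` instead of raising `KeyError`, so only the attribute lookup itself can fail. A small illustration:

```python
models = {"SimpleEndUser": object()}

# dict.get never raises KeyError for a missing key; it returns None,
# which the `if ... is None` branch above then handles.
assert models.get("WorkflowRunNodeExecution") is None

# Only a broken or absent `.models` attribute raises, and that is an AttributeError:
try:
    None.models.get("SimpleEndUser")  # stand-in for an uninitialized namespace
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'models'
```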
@@ -1,14 +1,13 @@
import logging

from flask_restx import Resource, marshal_with, reqparse
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from configs import dify_config
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
@@ -16,12 +15,35 @@ from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
from models.trigger import AppTrigger, WorkflowWebhookTrigger

from .. import console_ns
from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required

logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class Parser(BaseModel):
    node_id: str


class ParserEnable(BaseModel):
    trigger_id: str
    enable_trigger: bool


console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

console_ns.schema_model(
    ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)


@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
class WebhookTriggerApi(Resource):
    """Webhook Trigger API"""

    @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
    @setup_required
    @login_required
    @account_initialization_required
@@ -29,10 +51,9 @@ class WebhookTriggerApi(Resource):
    @marshal_with(webhook_trigger_fields)
    def get(self, app_model: App):
        """Get webhook trigger for a node"""
        parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
        args = parser.parse_args()
        args = Parser.model_validate(request.args.to_dict(flat=True))  # type: ignore

        node_id = str(args["node_id"])
        node_id = args.node_id

        with Session(db.engine) as session:
            # Get webhook trigger for this app and node
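This hunk is representative of the commit's broader migration: `reqparse.RequestParser` is replaced by a Pydantic model validated against the raw query string. A self-contained sketch of the same pattern, using a plain dict in place of `request.args.to_dict(flat=True)`:

```python
from pydantic import BaseModel, ValidationError


class Parser(BaseModel):
    node_id: str


# request.args.to_dict(flat=True) yields a plain {str: str} mapping.
query = {"node_id": "start-node-1"}
args = Parser.model_validate(query)
print(args.node_id)  # attribute access replaces args["node_id"]

try:
    Parser.model_validate({})  # missing required field
except ValidationError as e:
    print(e.errors()[0]["type"])  # -> "missing"
```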
@@ -51,6 +72,7 @@ class WebhookTriggerApi(Resource):
        return webhook_trigger


@console_ns.route("/apps/<uuid:app_id>/triggers")
class AppTriggersApi(Resource):
    """App Triggers list API"""
@@ -90,7 +112,9 @@ class AppTriggersApi(Resource):
        return {"data": triggers}


@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
    @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
    @setup_required
    @login_required
    @account_initialization_required
@@ -99,17 +123,11 @@ class AppTriggerEnableApi(Resource):
    @marshal_with(trigger_fields)
    def post(self, app_model: App):
        """Update app trigger (enable/disable)"""
        parser = (
            reqparse.RequestParser()
            .add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
            .add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
        )
        args = parser.parse_args()
        args = ParserEnable.model_validate(console_ns.payload)

        assert current_user.current_tenant_id is not None

        trigger_id = args["trigger_id"]

        trigger_id = args.trigger_id
        with Session(db.engine) as session:
            # Find the trigger using select
            trigger = session.execute(
@@ -124,7 +142,7 @@ class AppTriggerEnableApi(Resource):
            raise NotFound("Trigger not found")

        # Update status based on enable_trigger boolean
        trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
        trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED

        session.commit()
        session.refresh(trigger)
@@ -137,8 +155,3 @@ class AppTriggerEnableApi(Resource):
            trigger.icon = ""  # type: ignore

        return trigger


console_ns.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
console_ns.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
console_ns.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
@@ -15,7 +15,6 @@ from controllers.console.app.error import (
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
    ModelCurrentlyNotSupportError,
@@ -31,6 +30,7 @@ from libs.login import current_user
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError

from .. import console_ns
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
class CompletionApi(InstalledAppResource):
    def post(self, installed_app):
        app_model = installed_app.app
        if app_model.mode != "completion":
        if app_model.mode != AppMode.COMPLETION:
            raise NotCompletionAppError()

        parser = (
@@ -102,12 +102,18 @@ class CompletionApi(InstalledAppResource):
class CompletionStopApi(InstalledAppResource):
    def post(self, installed_app, task_id):
        app_model = installed_app.app
        if app_model.mode != "completion":
        if app_model.mode != AppMode.COMPLETION:
            raise NotCompletionAppError()

        if not isinstance(current_user, Account):
            raise ValueError("current_user must be an Account instance")
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

        AppTaskService.stop_task(
            task_id=task_id,
            invoke_from=InvokeFrom.EXPLORE,
            user_id=current_user.id,
            app_mode=AppMode.value_of(app_model.mode),
        )

        return {"result": "success"}, 200
@@ -184,6 +190,12 @@ class ChatStopApi(InstalledAppResource):

        if not isinstance(current_user, Account):
            raise ValueError("current_user must be an Account instance")
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

        AppTaskService.stop_task(
            task_id=task_id,
            invoke_from=InvokeFrom.EXPLORE,
            user_id=current_user.id,
            app_mode=app_mode,
        )

        return {"result": "success"}, 200
@@ -1,8 +1,10 @@
from datetime import datetime
from typing import Literal

import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -42,20 +44,198 @@ from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


def _init_parser():
    parser = reqparse.RequestParser()
    if dify_config.EDITION == "CLOUD":
        parser.add_argument("invitation_code", type=str, location="json")
    parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
        "timezone", type=timezone, required=True, location="json"
    )
    return parser


class AccountInitPayload(BaseModel):
    interface_language: str
    timezone: str
    invitation_code: str | None = None

    @field_validator("interface_language")
    @classmethod
    def validate_language(cls, value: str) -> str:
        return supported_language(value)

    @field_validator("timezone")
    @classmethod
    def validate_timezone(cls, value: str) -> str:
        return timezone(value)


class AccountNamePayload(BaseModel):
    name: str = Field(min_length=3, max_length=30)


class AccountAvatarPayload(BaseModel):
    avatar: str


class AccountInterfaceLanguagePayload(BaseModel):
    interface_language: str

    @field_validator("interface_language")
    @classmethod
    def validate_language(cls, value: str) -> str:
        return supported_language(value)


class AccountInterfaceThemePayload(BaseModel):
    interface_theme: Literal["light", "dark"]


class AccountTimezonePayload(BaseModel):
    timezone: str

    @field_validator("timezone")
    @classmethod
    def validate_timezone(cls, value: str) -> str:
        return timezone(value)


class AccountPasswordPayload(BaseModel):
    password: str | None = None
    new_password: str
    repeat_new_password: str

    @model_validator(mode="after")
    def check_passwords_match(self) -> "AccountPasswordPayload":
        if self.new_password != self.repeat_new_password:
            raise RepeatPasswordNotMatchError()
        return self
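`AccountPasswordPayload` uses `model_validator(mode="after")` for the cross-field check that the old reqparse handler performed imperatively. A standalone sketch of the same idea, with a plain `ValueError` standing in for Dify's `RepeatPasswordNotMatchError`:

```python
from pydantic import BaseModel, ValidationError, model_validator


class PasswordPayload(BaseModel):
    new_password: str
    repeat_new_password: str

    @model_validator(mode="after")
    def check_passwords_match(self) -> "PasswordPayload":
        # Runs after per-field validation, so both fields are populated here.
        if self.new_password != self.repeat_new_password:
            raise ValueError("passwords do not match")
        return self


PasswordPayload(new_password="s3cret!", repeat_new_password="s3cret!")  # ok
try:
    PasswordPayload(new_password="s3cret!", repeat_new_password="oops")
except ValidationError as e:
    print(e.errors()[0]["msg"])  # -> "Value error, passwords do not match"
```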
class AccountDeletePayload(BaseModel):
    token: str
    code: str


class AccountDeletionFeedbackPayload(BaseModel):
    email: str
    feedback: str

    @field_validator("email")
    @classmethod
    def validate_email(cls, value: str) -> str:
        return email(value)


class EducationActivatePayload(BaseModel):
    token: str
    institution: str
    role: str


class EducationAutocompleteQuery(BaseModel):
    keywords: str
    page: int = 0
    limit: int = 20


class ChangeEmailSendPayload(BaseModel):
    email: str
    language: str | None = None
    phase: str | None = None
    token: str | None = None

    @field_validator("email")
    @classmethod
    def validate_email(cls, value: str) -> str:
        return email(value)


class ChangeEmailValidityPayload(BaseModel):
    email: str
    code: str
    token: str

    @field_validator("email")
    @classmethod
    def validate_email(cls, value: str) -> str:
        return email(value)


class ChangeEmailResetPayload(BaseModel):
    new_email: str
    token: str

    @field_validator("new_email")
    @classmethod
    def validate_email(cls, value: str) -> str:
        return email(value)


class CheckEmailUniquePayload(BaseModel):
    email: str

    @field_validator("email")
    @classmethod
    def validate_email(cls, value: str) -> str:
        return email(value)

console_ns.schema_model(
    AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
    AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
    AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
    AccountInterfaceLanguagePayload.__name__,
    AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountInterfaceThemePayload.__name__,
    AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountTimezonePayload.__name__,
    AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountPasswordPayload.__name__,
    AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountDeletePayload.__name__,
    AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountDeletionFeedbackPayload.__name__,
    AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    EducationActivatePayload.__name__,
    EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    EducationAutocompleteQuery.__name__,
    EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    ChangeEmailSendPayload.__name__,
    ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    ChangeEmailValidityPayload.__name__,
    ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    ChangeEmailResetPayload.__name__,
    ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    CheckEmailUniquePayload.__name__,
    CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


@console_ns.route("/account/init")
class AccountInitApi(Resource):
    @console_ns.expect(_init_parser())
    @console_ns.expect(console_ns.models[AccountInitPayload.__name__])
    @setup_required
    @login_required
    def post(self):
@@ -64,17 +244,18 @@ class AccountInitApi(Resource):
        if account.status == "active":
            raise AccountAlreadyInitedError()

        args = _init_parser().parse_args()
        payload = console_ns.payload or {}
        args = AccountInitPayload.model_validate(payload)

        if dify_config.EDITION == "CLOUD":
            if not args["invitation_code"]:
            if not args.invitation_code:
                raise ValueError("invitation_code is required")

            # check invitation code
            invitation_code = (
                db.session.query(InvitationCode)
                .where(
                    InvitationCode.code == args["invitation_code"],
                    InvitationCode.code == args.invitation_code,
                    InvitationCode.status == "unused",
                )
                .first()
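The recurring replacement throughout these hunks is `parser.parse_args()` becoming `Payload.model_validate(console_ns.payload or {})`, where `Namespace.payload` is flask-restx's shortcut for the request's parsed JSON body. A sketch of the pattern in isolation, with a plain dict standing in for `console_ns.payload`:

```python
from pydantic import BaseModel, ValidationError


class AccountInitPayload(BaseModel):
    interface_language: str
    timezone: str
    invitation_code: str | None = None


# console_ns.payload is the parsed JSON body (or None), hence the `or {}`.
payload = {"interface_language": "en-US", "timezone": "America/New_York"}
args = AccountInitPayload.model_validate(payload)
print(args.invitation_code)  # -> None; dict-style args["..."] becomes attribute access

try:
    AccountInitPayload.model_validate({})  # empty body fails fast with field errors
except ValidationError as e:
    print(sorted(err["loc"][0] for err in e.errors()))  # -> ['interface_language', 'timezone']
```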
@@ -88,8 +269,8 @@ class AccountInitApi(Resource):
            invitation_code.used_by_tenant_id = account.current_tenant_id
            invitation_code.used_by_account_id = account.id

        account.interface_language = args["interface_language"]
        account.timezone = args["timezone"]
        account.interface_language = args.interface_language
        account.timezone = args.timezone
        account.interface_theme = "light"
        account.status = "active"
        account.initialized_at = naive_utc_now()
@@ -110,137 +291,104 @@ class AccountProfileApi(Resource):
        return current_user


parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")


@console_ns.route("/account/name")
class AccountNameApi(Resource):
    @console_ns.expect(parser_name)
    @console_ns.expect(console_ns.models[AccountNamePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        current_user, _ = current_account_with_tenant()
        args = parser_name.parse_args()

        # Validate account name length
        if len(args["name"]) < 3 or len(args["name"]) > 30:
            raise ValueError("Account name must be between 3 and 30 characters.")

        updated_account = AccountService.update_account(current_user, name=args["name"])
        payload = console_ns.payload or {}
        args = AccountNamePayload.model_validate(payload)
        updated_account = AccountService.update_account(current_user, name=args.name)

        return updated_account


parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")


@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
    @console_ns.expect(parser_avatar)
    @console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        current_user, _ = current_account_with_tenant()
        args = parser_avatar.parse_args()
        payload = console_ns.payload or {}
        args = AccountAvatarPayload.model_validate(payload)

        updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
        updated_account = AccountService.update_account(current_user, avatar=args.avatar)

        return updated_account


parser_interface = reqparse.RequestParser().add_argument(
    "interface_language", type=supported_language, required=True, location="json"
)


@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
    @console_ns.expect(parser_interface)
    @console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        current_user, _ = current_account_with_tenant()
        args = parser_interface.parse_args()
        payload = console_ns.payload or {}
        args = AccountInterfaceLanguagePayload.model_validate(payload)

        updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
        updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)

        return updated_account


parser_theme = reqparse.RequestParser().add_argument(
    "interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)


@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
    @console_ns.expect(parser_theme)
    @console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        current_user, _ = current_account_with_tenant()
        args = parser_theme.parse_args()
        payload = console_ns.payload or {}
        args = AccountInterfaceThemePayload.model_validate(payload)

        updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
        updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)

        return updated_account


parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")


@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
    @console_ns.expect(parser_timezone)
    @console_ns.expect(console_ns.models[AccountTimezonePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        current_user, _ = current_account_with_tenant()
        args = parser_timezone.parse_args()
        payload = console_ns.payload or {}
        args = AccountTimezonePayload.model_validate(payload)

        # Validate timezone string, e.g. America/New_York, Asia/Shanghai
        if args["timezone"] not in pytz.all_timezones:
            raise ValueError("Invalid timezone string.")

        updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
        updated_account = AccountService.update_account(current_user, timezone=args.timezone)

        return updated_account


parser_pw = (
    reqparse.RequestParser()
    .add_argument("password", type=str, required=False, location="json")
    .add_argument("new_password", type=str, required=True, location="json")
    .add_argument("repeat_new_password", type=str, required=True, location="json")
)


@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
    @console_ns.expect(parser_pw)
    @console_ns.expect(console_ns.models[AccountPasswordPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        current_user, _ = current_account_with_tenant()
        args = parser_pw.parse_args()

        if args["new_password"] != args["repeat_new_password"]:
            raise RepeatPasswordNotMatchError()
        payload = console_ns.payload or {}
        args = AccountPasswordPayload.model_validate(payload)

        try:
            AccountService.update_account_password(current_user, args["password"], args["new_password"])
            AccountService.update_account_password(current_user, args.password, args.new_password)
        except ServiceCurrentPasswordIncorrectError:
            raise CurrentPasswordIncorrectError()
@@ -316,25 +464,19 @@ class AccountDeleteVerifyApi(Resource):
        return {"result": "success", "data": token}


parser_delete = (
    reqparse.RequestParser()
    .add_argument("token", type=str, required=True, location="json")
    .add_argument("code", type=str, required=True, location="json")
)


@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
    @console_ns.expect(parser_delete)
    @console_ns.expect(console_ns.models[AccountDeletePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        account, _ = current_account_with_tenant()

        args = parser_delete.parse_args()
        payload = console_ns.payload or {}
        args = AccountDeletePayload.model_validate(payload)

        if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
        if not AccountService.verify_account_deletion_code(args.token, args.code):
            raise InvalidAccountDeletionCodeError()

        AccountService.delete_account(account)
@@ -342,21 +484,15 @@ class AccountDeleteApi(Resource):
        return {"result": "success"}


parser_feedback = (
    reqparse.RequestParser()
    .add_argument("email", type=str, required=True, location="json")
    .add_argument("feedback", type=str, required=True, location="json")
)


@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
    @console_ns.expect(parser_feedback)
    @console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__])
    @setup_required
    def post(self):
        args = parser_feedback.parse_args()
        payload = console_ns.payload or {}
        args = AccountDeletionFeedbackPayload.model_validate(payload)

        BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
        BillingService.update_account_deletion_feedback(args.email, args.feedback)

        return {"result": "success"}
@@ -379,14 +515,6 @@ class EducationVerifyApi(Resource):
        return BillingService.EducationIdentity.verify(account.id, account.email)


parser_edu = (
    reqparse.RequestParser()
    .add_argument("token", type=str, required=True, location="json")
    .add_argument("institution", type=str, required=True, location="json")
    .add_argument("role", type=str, required=True, location="json")
)


@console_ns.route("/account/education")
class EducationApi(Resource):
    status_fields = {
@@ -396,7 +524,7 @@ class EducationApi(Resource):
        "allow_refresh": fields.Boolean,
    }

    @console_ns.expect(parser_edu)
    @console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
@@ -405,9 +533,10 @@ class EducationApi(Resource):
    def post(self):
        account, _ = current_account_with_tenant()

        args = parser_edu.parse_args()
        payload = console_ns.payload or {}
        args = EducationActivatePayload.model_validate(payload)

        return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
        return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role)

    @setup_required
    @login_required
@@ -425,14 +554,6 @@ class EducationApi(Resource):
        return res


parser_autocomplete = (
    reqparse.RequestParser()
    .add_argument("keywords", type=str, required=True, location="args")
    .add_argument("page", type=int, required=False, location="args", default=0)
    .add_argument("limit", type=int, required=False, location="args", default=20)
)


@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
    data_fields = {
@@ -441,7 +562,7 @@ class EducationAutoCompleteApi(Resource):
        "has_next": fields.Boolean,
    }

    @console_ns.expect(parser_autocomplete)
    @console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
    @setup_required
    @login_required
    @account_initialization_required
@@ -449,46 +570,39 @@ class EducationAutoCompleteApi(Resource):
    @cloud_edition_billing_enabled
    @marshal_with(data_fields)
    def get(self):
        args = parser_autocomplete.parse_args()
        payload = request.args.to_dict(flat=True)  # type: ignore
        args = EducationAutocompleteQuery.model_validate(payload)

        return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])


parser_change_email = (
    reqparse.RequestParser()
    .add_argument("email", type=email, required=True, location="json")
    .add_argument("language", type=str, required=False, location="json")
    .add_argument("phase", type=str, required=False, location="json")
    .add_argument("token", type=str, required=False, location="json")
)
        return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)


@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
    @console_ns.expect(parser_change_email)
    @console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__])
    @enable_change_email
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        current_user, _ = current_account_with_tenant()
        args = parser_change_email.parse_args()
        payload = console_ns.payload or {}
        args = ChangeEmailSendPayload.model_validate(payload)

        ip_address = extract_remote_ip(request)
        if AccountService.is_email_send_ip_limit(ip_address):
            raise EmailSendIpLimitError()

        if args["language"] is not None and args["language"] == "zh-Hans":
        if args.language is not None and args.language == "zh-Hans":
            language = "zh-Hans"
        else:
            language = "en-US"
        account = None
        user_email = args["email"]
        if args["phase"] is not None and args["phase"] == "new_email":
            if args["token"] is None:
        user_email = args.email
        if args.phase is not None and args.phase == "new_email":
            if args.token is None:
                raise InvalidTokenError()

            reset_data = AccountService.get_change_email_data(args["token"])
            reset_data = AccountService.get_change_email_data(args.token)
            if reset_data is None:
                raise InvalidTokenError()
            user_email = reset_data.get("email", "")
@@ -497,118 +611,103 @@ class ChangeEmailSendEmailApi(Resource):
                raise InvalidEmailError()
        else:
            with Session(db.engine) as session:
                account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
                account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
                if account is None:
                    raise AccountNotFound()

        token = AccountService.send_change_email_email(
            account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
            account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
        )
        return {"result": "success", "data": token}


parser_validity = (
    reqparse.RequestParser()
    .add_argument("email", type=email, required=True, location="json")
    .add_argument("code", type=str, required=True, location="json")
    .add_argument("token", type=str, required=True, nullable=False, location="json")
)


@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
    @console_ns.expect(parser_validity)
    @console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__])
    @enable_change_email
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        args = parser_validity.parse_args()
        payload = console_ns.payload or {}
        args = ChangeEmailValidityPayload.model_validate(payload)

        user_email = args["email"]
        user_email = args.email

        is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
        is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
        if is_change_email_error_rate_limit:
            raise EmailChangeLimitError()

        token_data = AccountService.get_change_email_data(args["token"])
        token_data = AccountService.get_change_email_data(args.token)
        if token_data is None:
            raise InvalidTokenError()

        if user_email != token_data.get("email"):
            raise InvalidEmailError()

        if args["code"] != token_data.get("code"):
            AccountService.add_change_email_error_rate_limit(args["email"])
        if args.code != token_data.get("code"):
            AccountService.add_change_email_error_rate_limit(args.email)
            raise EmailCodeError()

        # Verified, revoke the first token
        AccountService.revoke_change_email_token(args["token"])
        AccountService.revoke_change_email_token(args.token)

        # Refresh token data by generating a new token
        _, new_token = AccountService.generate_change_email_token(
            user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
            user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
        )

        AccountService.reset_change_email_error_rate_limit(args["email"])
        AccountService.reset_change_email_error_rate_limit(args.email)
        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


parser_reset = (
    reqparse.RequestParser()
    .add_argument("new_email", type=email, required=True, location="json")
    .add_argument("token", type=str, required=True, nullable=False, location="json")
)


@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
    @console_ns.expect(parser_reset)
    @console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__])
    @enable_change_email
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        args = parser_reset.parse_args()
        payload = console_ns.payload or {}
        args = ChangeEmailResetPayload.model_validate(payload)

        if AccountService.is_account_in_freeze(args["new_email"]):
        if AccountService.is_account_in_freeze(args.new_email):
            raise AccountInFreezeError()

        if not AccountService.check_email_unique(args["new_email"]):
        if not AccountService.check_email_unique(args.new_email):
            raise EmailAlreadyInUseError()

        reset_data = AccountService.get_change_email_data(args["token"])
        reset_data = AccountService.get_change_email_data(args.token)
        if not reset_data:
            raise InvalidTokenError()

        AccountService.revoke_change_email_token(args["token"])
        AccountService.revoke_change_email_token(args.token)

        old_email = reset_data.get("old_email", "")
        current_user, _ = current_account_with_tenant()
        if current_user.email != old_email:
            raise AccountNotFound()

        updated_account = AccountService.update_account_email(current_user, email=args["new_email"])
        updated_account = AccountService.update_account_email(current_user, email=args.new_email)

        AccountService.send_change_email_completed_notify_email(
            email=args["new_email"],
            email=args.new_email,
        )

        return updated_account


parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")


@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
    @console_ns.expect(parser_check)
    @console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__])
    @setup_required
    def post(self):
        args = parser_check.parse_args()
        if AccountService.is_account_in_freeze(args["email"]):
        payload = console_ns.payload or {}
        args = CheckEmailUniquePayload.model_validate(payload)
        if AccountService.is_account_in_freeze(args.email):
            raise AccountInFreezeError()
        if not AccountService.check_email_unique(args["email"]):
        if not AccountService.check_email_unique(args.email):
            raise EmailAlreadyInUseError()
        return {"result": "success"}
@@ -1,7 +1,8 @@
from urllib import parse

from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field

import services
from configs import dify_config
@@ -31,6 +32,53 @@ from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class MemberInvitePayload(BaseModel):
    emails: list[str] = Field(default_factory=list)
    role: TenantAccountRole
    language: str | None = None


class MemberRoleUpdatePayload(BaseModel):
    role: str


class OwnerTransferEmailPayload(BaseModel):
    language: str | None = None


class OwnerTransferCheckPayload(BaseModel):
    code: str
    token: str


class OwnerTransferPayload(BaseModel):
    token: str


console_ns.schema_model(
    MemberInvitePayload.__name__,
    MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    MemberRoleUpdatePayload.__name__,
    MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    OwnerTransferEmailPayload.__name__,
    OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    OwnerTransferCheckPayload.__name__,
    OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    OwnerTransferPayload.__name__,
    OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
@@ -48,29 +96,22 @@ class MemberListApi(Resource):
        return {"result": "success", "accounts": members}, 200


parser_invite = (
    reqparse.RequestParser()
    .add_argument("emails", type=list, required=True, location="json")
    .add_argument("role", type=str, required=True, default="admin", location="json")
    .add_argument("language", type=str, required=False, location="json")
)


@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
    """Invite a new member by email."""

    @console_ns.expect(parser_invite)
    @console_ns.expect(console_ns.models[MemberInvitePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("members")
    def post(self):
        args = parser_invite.parse_args()
        payload = console_ns.payload or {}
        args = MemberInvitePayload.model_validate(payload)

        invitee_emails = args["emails"]
        invitee_role = args["role"]
        interface_language = args["language"]
        invitee_emails = args.emails
        invitee_role = args.role
        interface_language = args.language
        if not TenantAccountRole.is_non_owner_role(invitee_role):
            return {"code": "invalid-role", "message": "Invalid role"}, 400
        current_user, _ = current_account_with_tenant()
@@ -146,20 +187,18 @@ class MemberCancelInviteApi(Resource):
        }, 200


parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")


@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
    """Update member role."""

    @console_ns.expect(parser_update)
    @console_ns.expect(console_ns.models[MemberRoleUpdatePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def put(self, member_id):
        args = parser_update.parse_args()
        new_role = args["role"]
        payload = console_ns.payload or {}
        args = MemberRoleUpdatePayload.model_validate(payload)
        new_role = args.role

        if not TenantAccountRole.is_valid_role(new_role):
            return {"code": "invalid-role", "message": "Invalid role"}, 400
@@ -197,20 +236,18 @@ class DatasetOperatorMemberListApi(Resource):
        return {"result": "success", "accounts": members}, 200


parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")


@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
    """Send owner transfer email."""

    @console_ns.expect(parser_send)
    @console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @is_allow_transfer_owner
    def post(self):
        args = parser_send.parse_args()
        payload = console_ns.payload or {}
        args = OwnerTransferEmailPayload.model_validate(payload)
        ip_address = extract_remote_ip(request)
        if AccountService.is_email_send_ip_limit(ip_address):
            raise EmailSendIpLimitError()
@@ -221,7 +258,7 @@ class SendOwnerTransferEmailApi(Resource):
        if not TenantService.is_owner(current_user, current_user.current_tenant):
            raise NotOwnerError()

        if args["language"] is not None and args["language"] == "zh-Hans":
        if args.language is not None and args.language == "zh-Hans":
            language = "zh-Hans"
        else:
            language = "en-US"
@@ -238,22 +275,16 @@ class SendOwnerTransferEmailApi(Resource):
        return {"result": "success", "data": token}


parser_owner = (
    reqparse.RequestParser()
    .add_argument("code", type=str, required=True, location="json")
    .add_argument("token", type=str, required=True, nullable=False, location="json")
)


@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
    @console_ns.expect(parser_owner)
    @console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @is_allow_transfer_owner
    def post(self):
        args = parser_owner.parse_args()
        payload = console_ns.payload or {}
        args = OwnerTransferCheckPayload.model_validate(payload)
        # check if the current user is the owner of the workspace
        current_user, _ = current_account_with_tenant()
        if not current_user.current_tenant:
@@ -267,41 +298,37 @@ class OwnerTransferCheckApi(Resource):
        if is_owner_transfer_error_rate_limit:
            raise OwnerTransferLimitError()

        token_data = AccountService.get_owner_transfer_data(args["token"])
        token_data = AccountService.get_owner_transfer_data(args.token)
        if token_data is None:
            raise InvalidTokenError()

        if user_email != token_data.get("email"):
            raise InvalidEmailError()

        if args["code"] != token_data.get("code"):
        if args.code != token_data.get("code"):
            AccountService.add_owner_transfer_error_rate_limit(user_email)
            raise EmailCodeError()

        # Verified, revoke the first token
        AccountService.revoke_owner_transfer_token(args["token"])
        AccountService.revoke_owner_transfer_token(args.token)

        # Refresh token data by generating a new token
        _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
        _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args.code, additional_data={})

        AccountService.reset_owner_transfer_error_rate_limit(user_email)
        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


parser_owner_transfer = reqparse.RequestParser().add_argument(
    "token", type=str, required=True, nullable=False, location="json"
)


@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
    @console_ns.expect(parser_owner_transfer)
    @console_ns.expect(console_ns.models[OwnerTransferPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @is_allow_transfer_owner
    def post(self, member_id):
        args = parser_owner_transfer.parse_args()
        payload = console_ns.payload or {}
        args = OwnerTransferPayload.model_validate(payload)

        # check if the current user is the owner of the workspace
        current_user, _ = current_account_with_tenant()
@ -313,14 +340,14 @@ class OwnerTransfer(Resource):
|
|||
if current_user.id == str(member_id):
|
||||
raise CannotTransferOwnerToSelfError()
|
||||
|
||||
transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
|
||||
transfer_token_data = AccountService.get_owner_transfer_data(args.token)
|
||||
if not transfer_token_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if transfer_token_data.get("email") != current_user.email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
AccountService.revoke_owner_transfer_token(args["token"])
|
||||
AccountService.revoke_owner_transfer_token(args.token)
|
||||
|
||||
member = db.session.get(Account, str(member_id))
|
||||
if not member:
|
||||
|
|
|
|||
|
|
@ -1,31 +1,123 @@
|
|||
import io
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import send_file
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
parser_model = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ParserModelList(BaseModel):
|
||||
model_type: ModelType | None = None
|
||||
|
||||
|
||||
class ParserCredentialId(BaseModel):
|
||||
credential_id: str | None = None
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_optional_credential_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
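`ParserCredentialId` reuses `uuid_value` from `libs.helper` inside a `field_validator`, so an optional ID is only checked when one is supplied. The same pattern in isolation (class name illustrative; `uuid_value` is assumed to raise on malformed input and return the normalized string):

from pydantic import BaseModel, field_validator

class CredentialRef(BaseModel):
    credential_id: str | None = None

    @field_validator("credential_id")
    @classmethod
    def _validate_id(cls, value: str | None) -> str | None:
        if value is None:
            return value              # optional field: nothing to check when absent
        return uuid_value(value)      # from libs.helper; raises on a malformed UUID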
|
||||
|
||||
class ParserCredentialCreate(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
|
||||
|
||||
class ParserCredentialUpdate(BaseModel):
|
||||
credential_id: str
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_update_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialDelete(BaseModel):
|
||||
credential_id: str
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_delete_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialSwitch(BaseModel):
|
||||
credential_id: str
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_switch_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialValidate(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class ParserPreferredProviderType(BaseModel):
|
||||
preferred_provider_type: Literal["system", "custom"]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialId.__name__,
|
||||
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialCreate.__name__,
|
||||
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialUpdate.__name__,
|
||||
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialDelete.__name__,
|
||||
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialSwitch.__name__,
|
||||
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialValidate.__name__,
|
||||
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPreferredProviderType.__name__,
|
||||
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
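Each model above is exported once as a Swagger 2.0 schema and later referenced by class name in `@console_ns.expect`. A compact sketch of the wiring (flask-restx keeps registered schemas in `console_ns.models`):

# register the JSON schema under the class name; Swagger 2.0 resolves refs in #/definitions/
console_ns.schema_model(
    ParserModelList.__name__,
    ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

# the endpoint then documents its input by looking the schema back up by name:
@console_ns.expect(console_ns.models[ParserModelList.__name__])
def get(self): ...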
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
class ModelProviderListApi(Resource):
|
||||
@console_ns.expect(parser_model)
|
||||
@console_ns.expect(console_ns.models[ParserModelList.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -33,38 +125,18 @@ class ModelProviderListApi(Resource):
|
|||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
args = parser_model.parse_args()
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = ParserModelList.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
|
||||
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.model_type)
|
||||
|
||||
return jsonable_encoder({"data": provider_list})
|
||||
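Query strings need one extra step compared to JSON bodies: `request.args` is a MultiDict of strings, so `to_dict(flat=True)` collapses it before validation and Pydantic coerces the value into the `ModelType` enum. A hedged sketch (the `"llm"` value is assumed to be a valid `ModelType` member value):

from flask import request

payload = request.args.to_dict(flat=True)       # e.g. {"model_type": "llm"} or {}
args = ParserModelList.model_validate(payload)
# args.model_type is a ModelType member (or None when the parameter was omitted)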
|
||||
|
||||
parser_cred = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
|
||||
)
|
||||
parser_post_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
parser_put_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
parser_delete_cred = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
class ModelProviderCredentialApi(Resource):
|
||||
@console_ns.expect(parser_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialId.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -72,23 +144,25 @@ class ModelProviderCredentialApi(Resource):
|
|||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
# if credential_id is not provided, return the currently used credential
|
||||
args = parser_cred.parse_args()
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
args = ParserCredentialId.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credential(
|
||||
tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
|
||||
tenant_id=tenant_id, provider=provider, credential_id=args.credential_id
|
||||
)
|
||||
|
||||
return {"credentials": credentials}
|
||||
|
||||
@console_ns.expect(parser_post_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialCreate.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_post_cred.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialCreate.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -96,15 +170,15 @@ class ModelProviderCredentialApi(Resource):
|
|||
model_provider_service.create_provider_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
credentials=args.credentials,
|
||||
credential_name=args.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@console_ns.expect(parser_put_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialUpdate.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -112,7 +186,8 @@ class ModelProviderCredentialApi(Resource):
|
|||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_put_cred.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialUpdate.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -120,71 +195,64 @@ class ModelProviderCredentialApi(Resource):
|
|||
model_provider_service.update_provider_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
credentials=args.credentials,
|
||||
credential_id=args.credential_id,
|
||||
credential_name=args.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(parser_delete_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialDelete.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_delete_cred.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialDelete.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_switch = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@console_ns.expect(parser_switch)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_switch.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialSwitch.model_validate(payload)
|
||||
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credential_id=args["credential_id"],
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_validate = reqparse.RequestParser().add_argument(
|
||||
"credentials", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@console_ns.expect(parser_validate)
|
||||
@console_ns.expect(console_ns.models[ParserCredentialValidate.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_validate.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialValidate.model_validate(payload)
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
|
|
@ -195,7 +263,7 @@ class ModelProviderValidateApi(Resource):
|
|||
|
||||
try:
|
||||
model_provider_service.validate_provider_credentials(
|
||||
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
|
||||
tenant_id=tenant_id, provider=provider, credentials=args.credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
|
|
@ -228,19 +296,9 @@ class ModelProviderIconApi(Resource):
|
|||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||
|
||||
|
||||
parser_preferred = reqparse.RequestParser().add_argument(
|
||||
"preferred_provider_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=["system", "custom"],
|
||||
location="json",
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@console_ns.expect(parser_preferred)
|
||||
@console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -250,11 +308,12 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
args = parser_preferred.parse_args()
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserPreferredProviderType.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.switch_preferred_provider(
|
||||
tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
|
||||
tenant_id=tenant_id, provider=provider, preferred_provider_type=args.preferred_provider_type
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -1,52 +1,172 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
parser_get_default = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
class ParserGetDefault(BaseModel):
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class ParserPostDefault(BaseModel):
|
||||
class Inner(BaseModel):
|
||||
model_type: ModelType
|
||||
model: str
|
||||
provider: str | None = None
|
||||
|
||||
model_settings: list[Inner]
|
||||
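The nested `Inner` model validates each list item, which replaces the removed per-item checks in the handler below. A sketch of a body this endpoint would now accept (model names and values are illustrative):

payload = {
    "model_settings": [
        {"model_type": "llm", "model": "example-model", "provider": "example-provider"},
        {"model_type": "text-embedding", "model": "example-embedding"},  # provider may be omitted
    ]
}
args = ParserPostDefault.model_validate(payload)
for setting in args.model_settings:  # each item is a typed Inner instance
    print(setting.model_type, setting.model, setting.provider)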
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
parser_post_default = reqparse.RequestParser().add_argument(
|
||||
"model_settings", type=list, required=True, nullable=False, location="json"
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class ParserDeleteModels(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class LoadBalancingPayload(BaseModel):
|
||||
configs: list[dict[str, Any]] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class ParserPostModels(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
load_balancing: LoadBalancingPayload | None = None
|
||||
config_from: str | None = None
|
||||
credential_id: str | None = None
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_credential_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserGetCredentials(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
config_from: str | None = None
|
||||
credential_id: str | None = None
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_get_credential_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserCredentialBase(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class ParserCreateCredential(ParserCredentialBase):
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class ParserUpdateCredential(ParserCredentialBase):
|
||||
credential_id: str
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_update_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserDeleteCredential(ParserCredentialBase):
|
||||
credential_id: str
|
||||
|
||||
@field_validator("credential_id")
|
||||
@classmethod
|
||||
def validate_delete_credential_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ParserParameter(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGetCredentials.__name__,
|
||||
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCreateCredential.__name__,
|
||||
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserUpdateCredential.__name__,
|
||||
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDeleteCredential.__name__,
|
||||
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
class DefaultModelApi(Resource):
|
||||
@console_ns.expect(parser_get_default)
|
||||
@console_ns.expect(console_ns.models[ParserGetDefault.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_get_default.parse_args()
|
||||
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
tenant_id=tenant_id, model_type=args["model_type"]
|
||||
tenant_id=tenant_id, model_type=args.model_type
|
||||
)
|
||||
|
||||
return jsonable_encoder({"data": default_model_entity})
|
||||
|
||||
@console_ns.expect(parser_post_default)
|
||||
@console_ns.expect(console_ns.models[ParserPostDefault.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -54,66 +174,31 @@ class DefaultModelApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_post_default.parse_args()
|
||||
args = ParserPostDefault.model_validate(console_ns.payload)
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args["model_settings"]
|
||||
model_settings = args.model_settings
|
||||
for model_setting in model_settings:
|
||||
if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
|
||||
raise ValueError("invalid model type")
|
||||
|
||||
if "provider" not in model_setting:
|
||||
if model_setting.provider is None:
|
||||
continue
|
||||
|
||||
if "model" not in model_setting:
|
||||
raise ValueError("invalid model")
|
||||
|
||||
try:
|
||||
model_provider_service.update_default_model_of_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_setting["model_type"],
|
||||
provider=model_setting["provider"],
|
||||
model=model_setting["model"],
|
||||
model_type=model_setting.model_type,
|
||||
provider=model_setting.provider,
|
||||
model=model_setting.model,
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.exception(
|
||||
"Failed to update default model, model type: %s, model: %s",
|
||||
model_setting["model_type"],
|
||||
model_setting.get("model"),
|
||||
model_setting.model_type,
|
||||
model_setting.model,
|
||||
)
|
||||
raise ex
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_post_models = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser_delete_models = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
|
||||
class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -127,7 +212,7 @@ class ModelProviderModelApi(Resource):
|
|||
|
||||
return jsonable_encoder({"data": models})
|
||||
|
||||
@console_ns.expect(parser_post_models)
|
||||
@console_ns.expect(console_ns.models[ParserPostModels.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -135,45 +220,45 @@ class ModelProviderModelApi(Resource):
|
|||
def post(self, provider: str):
|
||||
# Save the model's load balancing configs
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = parser_post_models.parse_args()
|
||||
args = ParserPostModels.model_validate(console_ns.payload)
|
||||
|
||||
if args.get("config_from", "") == "custom-model":
|
||||
if not args.get("credential_id"):
|
||||
if args.config_from == "custom-model":
|
||||
if not args.credential_id:
|
||||
raise ValueError("credential_id is required when configuring a custom-model")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_custom_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
||||
if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
|
||||
if args.load_balancing and args.load_balancing.configs:
|
||||
# save load balancing configs
|
||||
model_load_balancing_service.update_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
configs=args["load_balancing"]["configs"],
|
||||
config_from=args.get("config_from", ""),
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
configs=args.load_balancing.configs,
|
||||
config_from=args.config_from or "",
|
||||
)
|
||||
|
||||
if args.get("load_balancing", {}).get("enabled"):
|
||||
if args.load_balancing.enabled:
|
||||
model_load_balancing_service.enable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
else:
|
||||
model_load_balancing_service.disable_model_load_balancing(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
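One note on the optional nested model: the configs branch is guarded with `args.load_balancing and args.load_balancing.configs`, and the later `enabled` check only stays safe if it runs under the same guard. If it ever sits at the top level, the defensive spelling below preserves the old falsy-default behaviour (a sketch, not part of this diff):

lb = args.load_balancing                     # LoadBalancingPayload | None
if lb is not None and lb.configs:
    model_load_balancing_service.update_load_balancing_configs(
        tenant_id=tenant_id, provider=provider, model=args.model,
        model_type=args.model_type, configs=lb.configs, config_from=args.config_from or "",
    )
if lb is not None and lb.enabled:
    model_load_balancing_service.enable_model_load_balancing(
        tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
    )
else:
    # covers both lb is None and enabled being falsy, matching the old .get() chain
    model_load_balancing_service.disable_model_load_balancing(
        tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
    )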
|
||||
@console_ns.expect(parser_delete_models)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -181,113 +266,53 @@ class ModelProviderModelApi(Resource):
|
|||
def delete(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_delete_models.parse_args()
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_get_credentials = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
parser_post_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser_put_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
parser_delete_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
class ModelProviderModelCredentialApi(Resource):
|
||||
@console_ns.expect(parser_get_credentials)
|
||||
@console_ns.expect(console_ns.models[ParserGetCredentials.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_get_credentials.parse_args()
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
current_credential = model_provider_service.get_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args.get("credential_id"),
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
config_from=args.get("config_from", ""),
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
config_from=args.config_from or "",
|
||||
)
|
||||
|
||||
if args.get("config_from", "") == "predefined-model":
|
||||
if args.config_from == "predefined-model":
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider
|
||||
)
|
||||
else:
|
||||
model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
|
||||
model_type = args.model_type
|
||||
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
|
||||
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
|
||||
)
|
||||
|
||||
return jsonable_encoder(
|
||||
|
|
@ -304,7 +329,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
}
|
||||
)
|
||||
|
||||
@console_ns.expect(parser_post_cred)
|
||||
@console_ns.expect(console_ns.models[ParserCreateCredential.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -312,7 +337,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_post_cred.parse_args()
|
||||
args = ParserCreateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -320,30 +345,30 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
model_provider_service.create_model_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
credentials=args.credentials,
|
||||
credential_name=args.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
logger.exception(
|
||||
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
|
||||
tenant_id,
|
||||
args.get("model"),
|
||||
args.get("model_type"),
|
||||
args.model,
|
||||
args.model_type,
|
||||
)
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@console_ns.expect(parser_put_cred)
|
||||
@console_ns.expect(console_ns.models[ParserUpdateCredential.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_put_cred.parse_args()
|
||||
args = ParserUpdateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -351,106 +376,87 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
model_provider_service.update_model_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credentials=args.credentials,
|
||||
credential_id=args.credential_id,
|
||||
credential_name=args.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(parser_delete_cred)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteCredential.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_delete_cred.parse_args()
|
||||
args = ParserDeleteCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credential(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_switch = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
class ParserSwitch(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credential_id: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@console_ns.expect(parser_switch)
|
||||
@console_ns.expect(console_ns.models[ParserSwitch.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_switch.parse_args()
|
||||
args = ParserSwitch.model_validate(console_ns.payload)
|
||||
|
||||
service = ModelProviderService()
|
||||
service.add_model_credential_to_model_list(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
credential_id=args["credential_id"],
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_model_enable_disable = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
|
||||
)
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
@console_ns.expect(parser_model_enable_disable)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_model_enable_disable.parse_args()
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.enable_model(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
@ -460,48 +466,43 @@ class ModelProviderModelEnableApi(Resource):
|
|||
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
|
||||
)
|
||||
class ModelProviderModelDisableApi(Resource):
|
||||
@console_ns.expect(parser_model_enable_disable)
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_model_enable_disable.parse_args()
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.disable_model(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_validate = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
class ParserValidate(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
@console_ns.expect(parser_validate)
|
||||
@console_ns.expect(console_ns.models[ParserValidate.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_validate.parse_args()
|
||||
args = ParserValidate.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -512,9 +513,9 @@ class ModelProviderModelValidateApi(Resource):
|
|||
model_provider_service.validate_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
credentials=args.credentials,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
|
|
@ -528,24 +529,19 @@ class ModelProviderModelValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
parser_parameter = reqparse.RequestParser().add_argument(
|
||||
"model", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
@console_ns.expect(parser_parameter)
|
||||
@console_ns.expect(console_ns.models[ParserParameter.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
args = parser_parameter.parse_args()
|
||||
args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"]
|
||||
tenant_id=tenant_id, provider=provider, model=args.model
|
||||
)
|
||||
|
||||
return jsonable_encoder({"data": parameter_rules})
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import io
|
||||
from typing import Literal
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -17,6 +19,8 @@ from services.plugin.plugin_parameter_service import PluginParameterService
|
|||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/debugging-key")
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
|
|
@ -37,88 +41,251 @@ class PluginDebuggingKeyApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_list = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
class ParserList(BaseModel):
|
||||
page: int = Field(default=1)
|
||||
page_size: int = Field(default=256)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list")
|
||||
class PluginListApi(Resource):
|
||||
@console_ns.expect(parser_list)
|
||||
@console_ns.expect(console_ns.models[ParserList.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = parser_list.parse_args()
|
||||
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
try:
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
|
||||
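`page` and `page_size` arrive as strings in the query, so the defaults on `ParserList` apply when they are absent and Pydantic coerces numeric strings to `int`:

ParserList.model_validate({})                                # page=1, page_size=256
ParserList.model_validate({"page": "3", "page_size": "50"})  # coerced to page=3, page_size=50
# non-numeric input raises pydantic.ValidationError instead of silently defaulting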
|
||||
|
||||
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
class ParserLatest(BaseModel):
|
||||
plugin_ids: list[str]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class ParserIcon(BaseModel):
|
||||
tenant_id: str
|
||||
filename: str
|
||||
|
||||
|
||||
class ParserAsset(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
file_name: str
|
||||
|
||||
|
||||
class ParserGithubUpload(BaseModel):
|
||||
repo: str
|
||||
version: str
|
||||
package: str
|
||||
|
||||
|
||||
class ParserPluginIdentifiers(BaseModel):
|
||||
plugin_unique_identifiers: list[str]
|
||||
|
||||
|
||||
class ParserGithubInstall(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
repo: str
|
||||
version: str
|
||||
package: str
|
||||
|
||||
|
||||
class ParserPluginIdentifierQuery(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
|
||||
|
||||
class ParserTasks(BaseModel):
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
class ParserMarketplaceUpgrade(BaseModel):
|
||||
original_plugin_unique_identifier: str
|
||||
new_plugin_unique_identifier: str
|
||||
|
||||
|
||||
class ParserGithubUpgrade(BaseModel):
|
||||
original_plugin_unique_identifier: str
|
||||
new_plugin_unique_identifier: str
|
||||
repo: str
|
||||
version: str
|
||||
package: str
|
||||
|
||||
|
||||
class ParserUninstall(BaseModel):
|
||||
plugin_installation_id: str
|
||||
|
||||
|
||||
class ParserPermissionChange(BaseModel):
|
||||
install_permission: TenantPluginPermission.InstallPermission
|
||||
debug_permission: TenantPluginPermission.DebugPermission
|
||||
|
||||
|
||||
class ParserDynamicOptions(BaseModel):
|
||||
plugin_id: str
|
||||
provider: str
|
||||
action: str
|
||||
parameter: str
|
||||
credential_id: str | None = None
|
||||
provider_type: Literal["tool", "trigger"]
|
||||
|
||||
|
||||
class PluginPermissionSettingsPayload(BaseModel):
|
||||
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
|
||||
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
||||
|
||||
class PluginAutoUpgradeSettingsPayload(BaseModel):
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = (
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
|
||||
)
|
||||
upgrade_time_of_day: int = 0
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
exclude_plugins: list[str] = Field(default_factory=list)
|
||||
include_plugins: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ParserPreferencesChange(BaseModel):
|
||||
permission: PluginPermissionSettingsPayload
|
||||
auto_upgrade: PluginAutoUpgradeSettingsPayload
|
||||
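`ParserPreferencesChange` nests two payload models whose fields all carry defaults, so a partial body still validates. A sketch (the enum string values are assumptions; Pydantic coerces them into the `TenantPlugin*` enum members):

payload = {
    "permission": {},                            # falls back to the EVERYONE defaults
    "auto_upgrade": {
        "strategy_setting": "fix_only",          # assumed member value of StrategySetting
        "exclude_plugins": ["org/plugin-name"],  # illustrative plugin id
    },
}
prefs = ParserPreferencesChange.model_validate(payload)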
|
||||
|
||||
class ParserExcludePlugin(BaseModel):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
class ParserReadme(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
language: str = Field(default="en-US")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPluginIdentifiers.__name__,
|
||||
ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPluginIdentifierQuery.__name__,
|
||||
ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserMarketplaceUpgrade.__name__,
|
||||
ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPermissionChange.__name__,
|
||||
ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDynamicOptions.__name__,
|
||||
ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPreferencesChange.__name__,
|
||||
ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserExcludePlugin.__name__,
|
||||
ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
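The run of near-identical `schema_model` calls above could be folded into a small helper; a possible refactor sketch, not part of this diff:

from pydantic import BaseModel

def register_schemas(ns, *models: type[BaseModel]) -> None:
    # register each model's JSON schema on the namespace under its class name
    for model in models:
        ns.schema_model(
            model.__name__,
            model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
        )

register_schemas(console_ns, ParserIcon, ParserAsset, ParserGithubUpload, ParserPluginIdentifiers)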
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
|
||||
class PluginListLatestVersionsApi(Resource):
|
||||
@console_ns.expect(parser_latest)
|
||||
@console_ns.expect(console_ns.models[ParserLatest.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = parser_latest.parse_args()
|
||||
args = ParserLatest.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
versions = PluginService.list_latest_versions(args["plugin_ids"])
|
||||
versions = PluginService.list_latest_versions(args.plugin_ids)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"versions": versions})
|
||||
|
||||
|
||||
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@console_ns.expect(parser_ids)
|
||||
@console_ns.expect(console_ns.models[ParserLatest.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_ids.parse_args()
|
||||
args = ParserLatest.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
parser_icon = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
.add_argument("filename", type=str, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/icon")
|
||||
class PluginIconApi(Resource):
|
||||
@console_ns.expect(parser_icon)
|
||||
@console_ns.expect(console_ns.models[ParserIcon.__name__])
|
||||
@setup_required
|
||||
def get(self):
|
||||
args = parser_icon.parse_args()
|
||||
args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
|
||||
icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
|
@@ -128,20 +295,16 @@ class PluginIconApi(Resource):

 @console_ns.route("/workspaces/current/plugin/asset")
 class PluginAssetApi(Resource):
+    @console_ns.expect(console_ns.models[ParserAsset.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
-        req = (
-            reqparse.RequestParser()
-            .add_argument("plugin_unique_identifier", type=str, required=True, location="args")
-            .add_argument("file_name", type=str, required=True, location="args")
-        )
-        args = req.parse_args()
+        args = ParserAsset.model_validate(request.args.to_dict(flat=True))  # type: ignore

         _, tenant_id = current_account_with_tenant()
         try:
-            binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
+            binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
             return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
         except PluginDaemonClientSideError as e:
             raise ValueError(e)

@@ -171,17 +334,9 @@ class PluginUploadFromPkgApi(Resource):
         return jsonable_encoder(response)


-parser_github = (
-    reqparse.RequestParser()
-    .add_argument("repo", type=str, required=True, location="json")
-    .add_argument("version", type=str, required=True, location="json")
-    .add_argument("package", type=str, required=True, location="json")
-)
-
-
 @console_ns.route("/workspaces/current/plugin/upload/github")
 class PluginUploadFromGithubApi(Resource):
-    @console_ns.expect(parser_github)
+    @console_ns.expect(console_ns.models[ParserGithubUpload.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -189,10 +344,10 @@ class PluginUploadFromGithubApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        args = parser_github.parse_args()
+        args = ParserGithubUpload.model_validate(console_ns.payload)

         try:
-            response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
+            response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package)
         except PluginDaemonClientSideError as e:
             raise ValueError(e)

@@ -223,47 +378,28 @@ class PluginUploadFromBundleApi(Resource):
         return jsonable_encoder(response)


-parser_pkg = reqparse.RequestParser().add_argument(
-    "plugin_unique_identifiers", type=list, required=True, location="json"
-)
-
-
 @console_ns.route("/workspaces/current/plugin/install/pkg")
 class PluginInstallFromPkgApi(Resource):
-    @console_ns.expect(parser_pkg)
+    @console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     @plugin_permission_required(install_required=True)
     def post(self):
         _, tenant_id = current_account_with_tenant()
-        args = parser_pkg.parse_args()
-
-        # check if all plugin_unique_identifiers are valid string
-        for plugin_unique_identifier in args["plugin_unique_identifiers"]:
-            if not isinstance(plugin_unique_identifier, str):
-                raise ValueError("Invalid plugin unique identifier")
+        args = ParserPluginIdentifiers.model_validate(console_ns.payload)

         try:
-            response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
+            response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers)
         except PluginDaemonClientSideError as e:
             raise ValueError(e)

         return jsonable_encoder(response)


-parser_githubapi = (
-    reqparse.RequestParser()
-    .add_argument("repo", type=str, required=True, location="json")
-    .add_argument("version", type=str, required=True, location="json")
-    .add_argument("package", type=str, required=True, location="json")
-    .add_argument("plugin_unique_identifier", type=str, required=True, location="json")
-)
-
-
 @console_ns.route("/workspaces/current/plugin/install/github")
 class PluginInstallFromGithubApi(Resource):
-    @console_ns.expect(parser_githubapi)
+    @console_ns.expect(console_ns.models[ParserGithubInstall.__name__])
     @setup_required
     @login_required
     @account_initialization_required
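Worth noting in the install-from-pkg hunk above: the hand-written loop that verified every entry of `plugin_unique_identifiers` was a string disappears without replacement, because a `list[str]` annotation on the model performs the same rejection during `model_validate`. A small illustration under that assumption (the real `ParserPluginIdentifiers` definition is not shown in this excerpt):

    from pydantic import BaseModel, ValidationError


    class ParserPluginIdentifiers(BaseModel):
        # list[str] subsumes the removed per-item isinstance check.
        plugin_unique_identifiers: list[str]


    try:
        ParserPluginIdentifiers.model_validate({"plugin_unique_identifiers": ["plugin-a", 123]})
    except ValidationError as exc:
        print(exc.error_count())  # 1 -- the integer entry is rejected as a non-string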
@@ -271,15 +407,15 @@ class PluginInstallFromGithubApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        args = parser_githubapi.parse_args()
+        args = ParserGithubInstall.model_validate(console_ns.payload)

         try:
             response = PluginService.install_from_github(
                 tenant_id,
-                args["plugin_unique_identifier"],
-                args["repo"],
-                args["version"],
-                args["package"],
+                args.plugin_unique_identifier,
+                args.repo,
+                args.version,
+                args.package,
             )
         except PluginDaemonClientSideError as e:
             raise ValueError(e)

@@ -287,14 +423,9 @@ class PluginInstallFromGithubApi(Resource):
         return jsonable_encoder(response)


-parser_marketplace = reqparse.RequestParser().add_argument(
-    "plugin_unique_identifiers", type=list, required=True, location="json"
-)
-
-
 @console_ns.route("/workspaces/current/plugin/install/marketplace")
 class PluginInstallFromMarketplaceApi(Resource):
-    @console_ns.expect(parser_marketplace)
+    @console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -302,43 +433,33 @@ class PluginInstallFromMarketplaceApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        args = parser_marketplace.parse_args()
-
-        # check if all plugin_unique_identifiers are valid string
-        for plugin_unique_identifier in args["plugin_unique_identifiers"]:
-            if not isinstance(plugin_unique_identifier, str):
-                raise ValueError("Invalid plugin unique identifier")
+        args = ParserPluginIdentifiers.model_validate(console_ns.payload)

         try:
-            response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
+            response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers)
         except PluginDaemonClientSideError as e:
             raise ValueError(e)

         return jsonable_encoder(response)


-parser_pkgapi = reqparse.RequestParser().add_argument(
-    "plugin_unique_identifier", type=str, required=True, location="args"
-)
-
-
 @console_ns.route("/workspaces/current/plugin/marketplace/pkg")
 class PluginFetchMarketplacePkgApi(Resource):
-    @console_ns.expect(parser_pkgapi)
+    @console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     @plugin_permission_required(install_required=True)
     def get(self):
         _, tenant_id = current_account_with_tenant()
-        args = parser_pkgapi.parse_args()
+        args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         try:
             return jsonable_encoder(
                 {
                     "manifest": PluginService.fetch_marketplace_pkg(
                         tenant_id,
-                        args["plugin_unique_identifier"],
+                        args.plugin_unique_identifier,
                     )
                 }
             )

@@ -346,14 +467,9 @@ class PluginFetchMarketplacePkgApi(Resource):
             raise ValueError(e)


-parser_fetch = reqparse.RequestParser().add_argument(
-    "plugin_unique_identifier", type=str, required=True, location="args"
-)
-
-
 @console_ns.route("/workspaces/current/plugin/fetch-manifest")
 class PluginFetchManifestApi(Resource):
-    @console_ns.expect(parser_fetch)
+    @console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -361,30 +477,19 @@ class PluginFetchManifestApi(Resource):
     def get(self):
         _, tenant_id = current_account_with_tenant()

-        args = parser_fetch.parse_args()
+        args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         try:
             return jsonable_encoder(
-                {
-                    "manifest": PluginService.fetch_plugin_manifest(
-                        tenant_id, args["plugin_unique_identifier"]
-                    ).model_dump()
-                }
+                {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()}
             )
         except PluginDaemonClientSideError as e:
             raise ValueError(e)


-parser_tasks = (
-    reqparse.RequestParser()
-    .add_argument("page", type=int, required=True, location="args")
-    .add_argument("page_size", type=int, required=True, location="args")
-)
-
-
 @console_ns.route("/workspaces/current/plugin/tasks")
 class PluginFetchInstallTasksApi(Resource):
-    @console_ns.expect(parser_tasks)
+    @console_ns.expect(console_ns.models[ParserTasks.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -392,12 +497,10 @@ class PluginFetchInstallTasksApi(Resource):
     def get(self):
         _, tenant_id = current_account_with_tenant()

-        args = parser_tasks.parse_args()
+        args = ParserTasks.model_validate(request.args.to_dict(flat=True))  # type: ignore

         try:
-            return jsonable_encoder(
-                {"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
-            )
+            return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
         except PluginDaemonClientSideError as e:
             raise ValueError(e)

@@ -462,16 +565,9 @@ class PluginDeleteInstallTaskItemApi(Resource):
             raise ValueError(e)


-parser_marketplace_api = (
-    reqparse.RequestParser()
-    .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
-    .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
-)
-
-
 @console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
 class PluginUpgradeFromMarketplaceApi(Resource):
-    @console_ns.expect(parser_marketplace_api)
+    @console_ns.expect(console_ns.models[ParserMarketplaceUpgrade.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -479,31 +575,21 @@ class PluginUpgradeFromMarketplaceApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        args = parser_marketplace_api.parse_args()
+        args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)

         try:
             return jsonable_encoder(
                 PluginService.upgrade_plugin_with_marketplace(
-                    tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
+                    tenant_id, args.original_plugin_unique_identifier, args.new_plugin_unique_identifier
                 )
             )
         except PluginDaemonClientSideError as e:
             raise ValueError(e)


-parser_github_post = (
-    reqparse.RequestParser()
-    .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
-    .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
-    .add_argument("repo", type=str, required=True, location="json")
-    .add_argument("version", type=str, required=True, location="json")
-    .add_argument("package", type=str, required=True, location="json")
-)
-
-
 @console_ns.route("/workspaces/current/plugin/upgrade/github")
 class PluginUpgradeFromGithubApi(Resource):
-    @console_ns.expect(parser_github_post)
+    @console_ns.expect(console_ns.models[ParserGithubUpgrade.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -511,56 +597,44 @@ class PluginUpgradeFromGithubApi(Resource):
     def post(self):
         _, tenant_id = current_account_with_tenant()

-        args = parser_github_post.parse_args()
+        args = ParserGithubUpgrade.model_validate(console_ns.payload)

         try:
             return jsonable_encoder(
                 PluginService.upgrade_plugin_with_github(
                     tenant_id,
-                    args["original_plugin_unique_identifier"],
-                    args["new_plugin_unique_identifier"],
-                    args["repo"],
-                    args["version"],
-                    args["package"],
+                    args.original_plugin_unique_identifier,
+                    args.new_plugin_unique_identifier,
+                    args.repo,
+                    args.version,
+                    args.package,
                 )
             )
         except PluginDaemonClientSideError as e:
             raise ValueError(e)


-parser_uninstall = reqparse.RequestParser().add_argument(
-    "plugin_installation_id", type=str, required=True, location="json"
-)
-
-
 @console_ns.route("/workspaces/current/plugin/uninstall")
 class PluginUninstallApi(Resource):
-    @console_ns.expect(parser_uninstall)
+    @console_ns.expect(console_ns.models[ParserUninstall.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     @plugin_permission_required(install_required=True)
     def post(self):
-        args = parser_uninstall.parse_args()
+        args = ParserUninstall.model_validate(console_ns.payload)

         _, tenant_id = current_account_with_tenant()

         try:
-            return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
+            return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
         except PluginDaemonClientSideError as e:
             raise ValueError(e)


-parser_change_post = (
-    reqparse.RequestParser()
-    .add_argument("install_permission", type=str, required=True, location="json")
-    .add_argument("debug_permission", type=str, required=True, location="json")
-)
-
-
 @console_ns.route("/workspaces/current/plugin/permission/change")
 class PluginChangePermissionApi(Resource):
-    @console_ns.expect(parser_change_post)
+    @console_ns.expect(console_ns.models[ParserPermissionChange.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -570,14 +644,15 @@ class PluginChangePermissionApi(Resource):
         if not user.is_admin_or_owner:
             raise Forbidden()

-        args = parser_change_post.parse_args()
-
-        install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
-        debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
+        args = ParserPermissionChange.model_validate(console_ns.payload)

         tenant_id = current_tenant_id

-        return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
+        return {
+            "success": PluginPermissionService.change_permission(
+                tenant_id, args.install_permission, args.debug_permission
+            )
+        }


 @console_ns.route("/workspaces/current/plugin/permission/fetch")
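The permission-change handler above now passes `args.install_permission` and `args.debug_permission` straight to the service, which only works if the model already coerces the raw payload strings into the permission enums that the removed lines constructed by hand. A sketch of what `ParserPermissionChange` presumably declares (the import path is a guess; the real definition lives elsewhere in the change set):

    from pydantic import BaseModel

    from models.account import TenantPluginPermission  # assumed import path


    class ParserPermissionChange(BaseModel):
        # Declaring the enum types directly makes model_validate perform the
        # same TenantPluginPermission(...) coercion the handler used to do inline.
        install_permission: TenantPluginPermission.InstallPermission
        debug_permission: TenantPluginPermission.DebugPermission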
@@ -605,20 +680,9 @@ class PluginFetchPermissionApi(Resource):
         )


-parser_dynamic = (
-    reqparse.RequestParser()
-    .add_argument("plugin_id", type=str, required=True, location="args")
-    .add_argument("provider", type=str, required=True, location="args")
-    .add_argument("action", type=str, required=True, location="args")
-    .add_argument("parameter", type=str, required=True, location="args")
-    .add_argument("credential_id", type=str, required=False, location="args")
-    .add_argument("provider_type", type=str, required=True, location="args")
-)
-
-
 @console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
 class PluginFetchDynamicSelectOptionsApi(Resource):
-    @console_ns.expect(parser_dynamic)
+    @console_ns.expect(console_ns.models[ParserDynamicOptions.__name__])
     @setup_required
     @login_required
     @is_admin_or_owner_required

@@ -627,18 +691,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
         current_user, tenant_id = current_account_with_tenant()
         user_id = current_user.id

-        args = parser_dynamic.parse_args()
+        args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True))  # type: ignore

         try:
             options = PluginParameterService.get_dynamic_select_options(
                 tenant_id=tenant_id,
                 user_id=user_id,
-                plugin_id=args["plugin_id"],
-                provider=args["provider"],
-                action=args["action"],
-                parameter=args["parameter"],
-                credential_id=args["credential_id"],
-                provider_type=args["provider_type"],
+                plugin_id=args.plugin_id,
+                provider=args.provider,
+                action=args.action,
+                parameter=args.parameter,
+                credential_id=args.credential_id,
+                provider_type=args.provider_type,
             )
         except PluginDaemonClientSideError as e:
             raise ValueError(e)

@@ -646,16 +710,9 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
         return jsonable_encoder({"options": options})


-parser_change = (
-    reqparse.RequestParser()
-    .add_argument("permission", type=dict, required=True, location="json")
-    .add_argument("auto_upgrade", type=dict, required=True, location="json")
-)
-
-
 @console_ns.route("/workspaces/current/plugin/preferences/change")
 class PluginChangePreferencesApi(Resource):
-    @console_ns.expect(parser_change)
+    @console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -664,22 +721,20 @@ class PluginChangePreferencesApi(Resource):
         if not user.is_admin_or_owner:
             raise Forbidden()

-        args = parser_change.parse_args()
+        args = ParserPreferencesChange.model_validate(console_ns.payload)

-        permission = args["permission"]
+        permission = args.permission

-        install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
-        debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
+        install_permission = permission.install_permission
+        debug_permission = permission.debug_permission

-        auto_upgrade = args["auto_upgrade"]
+        auto_upgrade = args.auto_upgrade

-        strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
-            auto_upgrade.get("strategy_setting", "fix_only")
-        )
-        upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
-        upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
-        exclude_plugins = auto_upgrade.get("exclude_plugins", [])
-        include_plugins = auto_upgrade.get("include_plugins", [])
+        strategy_setting = auto_upgrade.strategy_setting
+        upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
+        upgrade_mode = auto_upgrade.upgrade_mode
+        exclude_plugins = auto_upgrade.exclude_plugins
+        include_plugins = auto_upgrade.include_plugins

         # set permission
         set_permission_result = PluginPermissionService.change_permission(
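The preferences hunk above goes a step further: `permission` and `auto_upgrade` stop being raw dicts read with `.get(..., default)` and become attribute accesses, implying nested sub-models that carry the old defaults. A hedged sketch — class names and import path are inferred, not confirmed by this excerpt, while the default values mirror the removed `.get("...", ...)` fallbacks so absent keys keep behaving as before:

    from pydantic import BaseModel, Field

    from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission  # assumed import path


    class PermissionPayload(BaseModel):  # hypothetical name
        install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission("everyone")
        debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission("everyone")


    class AutoUpgradePayload(BaseModel):  # hypothetical name
        strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = (
            TenantPluginAutoUpgradeStrategy.StrategySetting("fix_only")
        )
        upgrade_time_of_day: int = 0
        upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode("exclude")
        exclude_plugins: list[str] = Field(default_factory=list)
        include_plugins: list[str] = Field(default_factory=list)


    class ParserPreferencesChange(BaseModel):
        permission: PermissionPayload
        auto_upgrade: AutoUpgradePayload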
@@ -744,12 +799,9 @@ class PluginFetchPreferencesApi(Resource):
         return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})


-parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
-
-
 @console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
 class PluginAutoUpgradeExcludePluginApi(Resource):
-    @console_ns.expect(parser_exclude)
+    @console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
     @setup_required
     @login_required
     @account_initialization_required

@@ -757,28 +809,20 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
         # exclude one single plugin
         _, tenant_id = current_account_with_tenant()

-        args = parser_exclude.parse_args()
+        args = ParserExcludePlugin.model_validate(console_ns.payload)

-        return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
+        return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})


 @console_ns.route("/workspaces/current/plugin/readme")
 class PluginReadmeApi(Resource):
+    @console_ns.expect(console_ns.models[ParserReadme.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
         _, tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("plugin_unique_identifier", type=str, required=True, location="args")
-            .add_argument("language", type=str, required=False, location="args")
-        )
-        args = parser.parse_args()
+        args = ParserReadme.model_validate(request.args.to_dict(flat=True))  # type: ignore
         return jsonable_encoder(
-            {
-                "readme": PluginService.fetch_plugin_readme(
-                    tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
-                )
-            }
+            {"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
         )
@@ -6,8 +6,6 @@ from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden

 from configs import dify_config
-from controllers.console import console_ns
-from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
 from controllers.web.error import NotFoundError
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin_daemon import CredentialType

@@ -23,9 +21,13 @@ from services.trigger.trigger_provider_service import TriggerProviderService
 from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
 from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService

+from .. import console_ns
+from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
+
 logger = logging.getLogger(__name__)


+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
 class TriggerProviderIconApi(Resource):
     @setup_required
     @login_required

@@ -38,6 +40,7 @@ class TriggerProviderIconApi(Resource):
         return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)


+@console_ns.route("/workspaces/current/triggers")
 class TriggerProviderListApi(Resource):
     @setup_required
     @login_required

@@ -50,6 +53,7 @@ class TriggerProviderListApi(Resource):
         return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))


+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
 class TriggerProviderInfoApi(Resource):
     @setup_required
     @login_required

@@ -64,6 +68,7 @@ class TriggerProviderInfoApi(Resource):
         )


+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
 class TriggerSubscriptionListApi(Resource):
     @setup_required
     @login_required

@@ -87,7 +92,16 @@ class TriggerSubscriptionListApi(Resource):
             raise


+parser = reqparse.RequestParser().add_argument(
+    "credential_type", type=str, required=False, nullable=True, location="json"
+)
+
+
+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
+)
 class TriggerSubscriptionBuilderCreateApi(Resource):
+    @console_ns.expect(parser)
     @setup_required
     @login_required
     @is_admin_or_owner_required

@@ -97,9 +111,6 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
         user = current_user
         assert user.current_tenant_id is not None

-        parser = reqparse.RequestParser().add_argument(
-            "credential_type", type=str, required=False, nullable=True, location="json"
-        )
         args = parser.parse_args()

         try:

@@ -116,6 +127,9 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
             raise


+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
+)
 class TriggerSubscriptionBuilderGetApi(Resource):
     @setup_required
     @login_required

@@ -127,7 +141,18 @@ class TriggerSubscriptionBuilderGetApi(Resource):
         )


+parser_api = (
+    reqparse.RequestParser()
+    # The credentials of the subscription builder
+    .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+)
+
+
+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
+)
 class TriggerSubscriptionBuilderVerifyApi(Resource):
+    @console_ns.expect(parser_api)
     @setup_required
     @login_required
     @is_admin_or_owner_required

@@ -136,12 +161,8 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
         """Verify a subscription instance for a trigger provider"""
         user = current_user
         assert user.current_tenant_id is not None
-        parser = (
-            reqparse.RequestParser()
-            # The credentials of the subscription builder
-            .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-        )
-        args = parser.parse_args()
+
+        args = parser_api.parse_args()

         try:
             # Use atomic update_and_verify to prevent race conditions

@@ -159,7 +180,24 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
             raise ValueError(str(e)) from e


+parser_update_api = (
+    reqparse.RequestParser()
+    # The name of the subscription builder
+    .add_argument("name", type=str, required=False, nullable=True, location="json")
+    # The parameters of the subscription builder
+    .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
+    # The properties of the subscription builder
+    .add_argument("properties", type=dict, required=False, nullable=True, location="json")
+    # The credentials of the subscription builder
+    .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+)
+
+
+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
+)
 class TriggerSubscriptionBuilderUpdateApi(Resource):
+    @console_ns.expect(parser_update_api)
     @setup_required
     @login_required
     @account_initialization_required

@@ -169,18 +207,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
         assert isinstance(user, Account)
         assert user.current_tenant_id is not None

-        parser = (
-            reqparse.RequestParser()
-            # The name of the subscription builder
-            .add_argument("name", type=str, required=False, nullable=True, location="json")
-            # The parameters of the subscription builder
-            .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
-            # The properties of the subscription builder
-            .add_argument("properties", type=dict, required=False, nullable=True, location="json")
-            # The credentials of the subscription builder
-            .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-        )
-        args = parser.parse_args()
+        args = parser_update_api.parse_args()
         try:
             return jsonable_encoder(
                 TriggerSubscriptionBuilderService.update_trigger_subscription_builder(

@@ -200,6 +227,9 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
             raise


+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
+)
 class TriggerSubscriptionBuilderLogsApi(Resource):
     @setup_required
     @login_required

@@ -218,7 +248,11 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
             raise


+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
+)
 class TriggerSubscriptionBuilderBuildApi(Resource):
+    @console_ns.expect(parser_update_api)
     @setup_required
     @login_required
     @is_admin_or_owner_required

@@ -227,18 +261,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
         """Build a subscription instance for a trigger provider"""
         user = current_user
         assert user.current_tenant_id is not None
-        parser = (
-            reqparse.RequestParser()
-            # The name of the subscription builder
-            .add_argument("name", type=str, required=False, nullable=True, location="json")
-            # The parameters of the subscription builder
-            .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
-            # The properties of the subscription builder
-            .add_argument("properties", type=dict, required=False, nullable=True, location="json")
-            # The credentials of the subscription builder
-            .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-        )
-        args = parser.parse_args()
+        args = parser_update_api.parse_args()
         try:
             # Use atomic update_and_build to prevent race conditions
             TriggerSubscriptionBuilderService.update_and_build_builder(

@@ -258,6 +281,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
             raise ValueError(str(e)) from e


+@console_ns.route(
+    "/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
+)
 class TriggerSubscriptionDeleteApi(Resource):
     @setup_required
     @login_required

@@ -291,6 +317,7 @@ class TriggerSubscriptionDeleteApi(Resource):
             raise


+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize")
 class TriggerOAuthAuthorizeApi(Resource):
     @setup_required
     @login_required

@@ -374,6 +401,7 @@ class TriggerOAuthAuthorizeApi(Resource):
             raise


+@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
 class TriggerOAuthCallbackApi(Resource):
     @setup_required
     def get(self, provider):

@@ -438,6 +466,14 @@ class TriggerOAuthCallbackApi(Resource):
         return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")


+parser_oauth_client = (
+    reqparse.RequestParser()
+    .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
+    .add_argument("enabled", type=bool, required=False, nullable=True, location="json")
+)
+
+
+@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
 class TriggerOAuthClientManageApi(Resource):
     @setup_required
     @login_required

@@ -484,6 +520,7 @@ class TriggerOAuthClientManageApi(Resource):
             logger.exception("Error getting OAuth client", exc_info=e)
             raise

+    @console_ns.expect(parser_oauth_client)
     @setup_required
     @login_required
     @is_admin_or_owner_required

@@ -493,12 +530,7 @@ class TriggerOAuthClientManageApi(Resource):
         user = current_user
         assert user.current_tenant_id is not None

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
-            .add_argument("enabled", type=bool, required=False, nullable=True, location="json")
-        )
-        args = parser.parse_args()
+        args = parser_oauth_client.parse_args()

         try:
             provider_id = TriggerProviderID(provider)

@@ -536,52 +568,3 @@ class TriggerOAuthClientManageApi(Resource):
         except Exception as e:
             logger.exception("Error removing OAuth client", exc_info=e)
             raise
-
-
-# Trigger Subscription
-console_ns.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
-console_ns.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
-console_ns.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
-console_ns.add_resource(
-    TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list"
-)
-console_ns.add_resource(
-    TriggerSubscriptionDeleteApi,
-    "/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
-)
-
-# Trigger Subscription Builder
-console_ns.add_resource(
-    TriggerSubscriptionBuilderCreateApi,
-    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
-)
-console_ns.add_resource(
-    TriggerSubscriptionBuilderGetApi,
-    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
-)
-console_ns.add_resource(
-    TriggerSubscriptionBuilderUpdateApi,
-    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
-)
-console_ns.add_resource(
-    TriggerSubscriptionBuilderVerifyApi,
-    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
-)
-console_ns.add_resource(
-    TriggerSubscriptionBuilderBuildApi,
-    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
-)
-console_ns.add_resource(
-    TriggerSubscriptionBuilderLogsApi,
-    "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
-)
-
-
-# OAuth
-console_ns.add_resource(
-    TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
-)
-console_ns.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
-console_ns.add_resource(
-    TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
-)
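Nothing is lost by deleting the `console_ns.add_resource(...)` block above: every resource in this file now binds its URL at class-definition time through the `@console_ns.route(...)` decorator added alongside it. The two registration styles are interchangeable in flask-restx; a minimal sketch:

    from flask_restx import Namespace, Resource

    ns = Namespace("demo")


    @ns.route("/ping")  # registers the resource as the decorator runs...
    class PingApi(Resource):
        def get(self):
            return {"pong": True}


    # ...which is equivalent to declaring the class undecorated and calling:
    # ns.add_resource(PingApi, "/ping")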
@@ -1,7 +1,8 @@
 import logging

 from flask import request
-from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field
 from sqlalchemy import select
 from werkzeug.exceptions import Unauthorized

@@ -32,6 +33,45 @@ from services.file_service import FileService
 from services.workspace_service import WorkspaceService

 logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkspaceListQuery(BaseModel):
+    page: int = Field(default=1, ge=1, le=99999)
+    limit: int = Field(default=20, ge=1, le=100)
+
+
+class SwitchWorkspacePayload(BaseModel):
+    tenant_id: str
+
+
+class WorkspaceCustomConfigPayload(BaseModel):
+    remove_webapp_brand: bool | None = None
+    replace_webapp_logo: str | None = None
+
+
+class WorkspaceInfoPayload(BaseModel):
+    name: str
+
+
+console_ns.schema_model(
+    WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+console_ns.schema_model(
+    SwitchWorkspacePayload.__name__,
+    SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+    WorkspaceCustomConfigPayload.__name__,
+    WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
+console_ns.schema_model(
+    WorkspaceInfoPayload.__name__,
+    WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)


 provider_fields = {
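One behavioral nuance in the models above: the old `inputs.int_range(1, 99999)` arguments made reqparse answer out-of-range values with an immediate 400, while `Field(ge=..., le=...)` makes `model_validate` raise a `ValidationError`; whether that surfaces as a 400 depends on error handlers this excerpt does not show. Coercion of query-string values still happens, as a quick illustration:

    from pydantic import ValidationError

    print(WorkspaceListQuery.model_validate({}).page)             # 1 -- defaults applied
    print(WorkspaceListQuery.model_validate({"page": "3"}).page)  # 3 -- query-string value coerced to int
    try:
        WorkspaceListQuery.model_validate({"limit": "500"})
    except ValidationError:
        print("limit > 100 rejected, matching the old inputs.int_range(1, 100)")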
@@ -95,18 +135,15 @@ class TenantListApi(Resource):

 @console_ns.route("/all-workspaces")
 class WorkspaceListApi(Resource):
+    @console_ns.expect(console_ns.models[WorkspaceListQuery.__name__])
     @setup_required
     @admin_required
     def get(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
-            .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
-        )
-        args = parser.parse_args()
+        payload = request.args.to_dict(flat=True)  # type: ignore
+        args = WorkspaceListQuery.model_validate(payload)

         stmt = select(Tenant).order_by(Tenant.created_at.desc())
-        tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
+        tenants = db.paginate(select=stmt, page=args.page, per_page=args.limit, error_out=False)
         has_more = False

         if tenants.has_next:

@@ -115,8 +152,8 @@ class WorkspaceListApi(Resource):
         return {
             "data": marshal(tenants.items, workspace_fields),
             "has_more": has_more,
-            "limit": args["limit"],
-            "page": args["page"],
+            "limit": args.limit,
+            "page": args.page,
             "total": tenants.total,
         }, 200

@@ -150,26 +187,24 @@ class TenantApi(Resource):
         return WorkspaceService.get_tenant_info(tenant), 200


-parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
-
-
 @console_ns.route("/workspaces/switch")
 class SwitchWorkspaceApi(Resource):
-    @console_ns.expect(parser_switch)
+    @console_ns.expect(console_ns.models[SwitchWorkspacePayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
         current_user, _ = current_account_with_tenant()
-        args = parser_switch.parse_args()
+        payload = console_ns.payload or {}
+        args = SwitchWorkspacePayload.model_validate(payload)

         # check if tenant_id is valid, 403 if not
         try:
-            TenantService.switch_tenant(current_user, args["tenant_id"])
+            TenantService.switch_tenant(current_user, args.tenant_id)
         except Exception:
             raise AccountNotLinkTenantError("Account not link tenant")

-        new_tenant = db.session.query(Tenant).get(args["tenant_id"])  # Get new tenant
+        new_tenant = db.session.query(Tenant).get(args.tenant_id)  # Get new tenant
         if new_tenant is None:
             raise ValueError("Tenant not found")

@@ -178,24 +213,21 @@ class SwitchWorkspaceApi(Resource):

 @console_ns.route("/workspaces/custom-config")
 class CustomConfigWorkspaceApi(Resource):
+    @console_ns.expect(console_ns.models[WorkspaceCustomConfigPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("workspace_custom")
     def post(self):
         _, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("remove_webapp_brand", type=bool, location="json")
-            .add_argument("replace_webapp_logo", type=str, location="json")
-        )
-        args = parser.parse_args()
+        payload = console_ns.payload or {}
+        args = WorkspaceCustomConfigPayload.model_validate(payload)
         tenant = db.get_or_404(Tenant, current_tenant_id)

         custom_config_dict = {
-            "remove_webapp_brand": args["remove_webapp_brand"],
-            "replace_webapp_logo": args["replace_webapp_logo"]
-            if args["replace_webapp_logo"] is not None
+            "remove_webapp_brand": args.remove_webapp_brand,
+            "replace_webapp_logo": args.replace_webapp_logo
+            if args.replace_webapp_logo is not None
             else tenant.custom_config_dict.get("replace_webapp_logo"),
         }

@@ -245,24 +277,22 @@ class WebappLogoWorkspaceApi(Resource):
         return {"id": upload_file.id}, 201


-parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
-
-
 @console_ns.route("/workspaces/info")
 class WorkspaceInfoApi(Resource):
-    @console_ns.expect(parser_info)
+    @console_ns.expect(console_ns.models[WorkspaceInfoPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
     # Change workspace name
     def post(self):
         _, current_tenant_id = current_account_with_tenant()
-        args = parser_info.parse_args()
+        payload = console_ns.payload or {}
+        args = WorkspaceInfoPayload.model_validate(payload)

         if not current_tenant_id:
             raise ValueError("No current tenant")
         tenant = db.get_or_404(Tenant, current_tenant_id)
-        tenant.name = args["name"]
+        tenant.name = args.name
         db.session.commit()

         return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@@ -17,7 +17,6 @@ from controllers.service_api.app.error import (
 )
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
-from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (
     ModelCurrentlyNotSupportError,

@@ -30,6 +29,7 @@ from libs import helper
 from libs.helper import uuid_value
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
+from services.app_task_service import AppTaskService
 from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
 from services.errors.llm import InvokeRateLimitError

@@ -88,7 +88,7 @@ class CompletionApi(Resource):
         This endpoint generates a completion based on the provided inputs and query.
         Supports both blocking and streaming response modes.
         """
-        if app_model.mode != "completion":
+        if app_model.mode != AppMode.COMPLETION:
             raise AppUnavailableError()

         args = completion_parser.parse_args()

@@ -147,10 +147,15 @@ class CompletionStopApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser, task_id: str):
         """Stop a running completion task."""
-        if app_model.mode != "completion":
+        if app_model.mode != AppMode.COMPLETION:
             raise AppUnavailableError()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
+        AppTaskService.stop_task(
+            task_id=task_id,
+            invoke_from=InvokeFrom.SERVICE_API,
+            user_id=end_user.id,
+            app_mode=AppMode.value_of(app_model.mode),
+        )

         return {"result": "success"}, 200

@@ -244,6 +249,11 @@ class ChatStopApi(Resource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
+        AppTaskService.stop_task(
+            task_id=task_id,
+            invoke_from=InvokeFrom.SERVICE_API,
+            user_id=end_user.id,
+            app_mode=app_mode,
+        )

         return {"result": "success"}, 200
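Both the service-API hunks above and the web hunks below swap the direct `AppQueueManager.set_stop_flag(...)` call for `AppTaskService.stop_task(...)`, threading an explicit `app_mode` through. The service itself is not part of this excerpt; purely as a sketch of the minimal facade the call sites imply (the real implementation in this change set may well do more, e.g. handle paused workflow runs):

    from core.app.apps.base_app_queue_manager import AppQueueManager
    from core.app.entities.app_invoke_entities import InvokeFrom
    from models.model import AppMode


    class AppTaskService:  # sketch only -- not the actual service from this change set
        @staticmethod
        def stop_task(*, task_id: str, invoke_from: InvokeFrom, user_id: str, app_mode: AppMode) -> None:
            # Reproduces the legacy behavior; accepting app_mode gives the
            # service a seam to branch on workflow-backed app types.
            AppQueueManager.set_stop_flag(task_id, invoke_from, user_id)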
@@ -1,7 +1,7 @@
 import logging
 import time

-from flask import jsonify
+from flask import jsonify, request
 from werkzeug.exceptions import NotFound, RequestEntityTooLarge

 from controllers.trigger import bp

@@ -28,8 +28,14 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
         webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
         return webhook_trigger, workflow, node_config, webhook_data, None
     except ValueError as e:
-        # Fall back to raw extraction for error reporting
-        webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
+        # Provide minimal context for error reporting without risking another parse failure
+        webhook_data = {
+            "method": request.method,
+            "headers": dict(request.headers),
+            "query_params": dict(request.args),
+            "body": {},
+            "files": {},
+        }
         return webhook_trigger, workflow, node_config, webhook_data, str(e)
@@ -17,7 +17,6 @@ from controllers.web.error import (
 )
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
 from controllers.web.wraps import WebApiResource
-from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (
     ModelCurrentlyNotSupportError,

@@ -29,6 +28,7 @@ from libs import helper
 from libs.helper import uuid_value
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService
+from services.app_task_service import AppTaskService
 from services.errors.llm import InvokeRateLimitError

 logger = logging.getLogger(__name__)

@@ -64,7 +64,7 @@ class CompletionApi(WebApiResource):
         }
     )
     def post(self, app_model, end_user):
-        if app_model.mode != "completion":
+        if app_model.mode != AppMode.COMPLETION:
             raise NotCompletionAppError()

         parser = (

@@ -125,10 +125,15 @@ class CompletionStopApi(WebApiResource):
         }
     )
     def post(self, app_model, end_user, task_id):
-        if app_model.mode != "completion":
+        if app_model.mode != AppMode.COMPLETION:
             raise NotCompletionAppError()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
+        AppTaskService.stop_task(
+            task_id=task_id,
+            invoke_from=InvokeFrom.WEB_APP,
+            user_id=end_user.id,
+            app_mode=AppMode.value_of(app_model.mode),
+        )

         return {"result": "success"}, 200

@@ -234,6 +239,11 @@ class ChatStopApi(WebApiResource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
+        AppTaskService.stop_task(
+            task_id=task_id,
+            invoke_from=InvokeFrom.WEB_APP,
+            user_id=end_user.id,
+            app_mode=app_mode,
+        )

         return {"result": "success"}, 200
|
|||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
|
|
@ -72,7 +73,7 @@ from extensions.ext_database import db
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -580,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif event.stopped_by in (
|
||||
|
|
@ -590,7 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
# When hitting input-moderation or annotation-reply, the workflow will not start
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session)
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
|
|
@ -599,6 +600,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
event: QueueAdvancedChatMessageEndEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle advanced chat message end events."""
|
||||
|
|
@ -616,7 +618,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
|
|
@ -770,7 +772,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
def _save_message(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
):
|
||||
message = self._get_message(session=session)
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
|
|
@ -809,6 +817,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
|
||||
metadata = self._task_state.metadata.model_dump()
|
||||
message.message_metadata = json.dumps(jsonable_encoder(metadata))
|
||||
|
||||
# Extract model provider and model_id from workflow node executions for tracing
|
||||
if message.workflow_run_id:
|
||||
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
|
||||
if model_info:
|
||||
message.model_provider = model_info.get("provider")
|
||||
message.model_id = model_info.get("model")
|
||||
|
||||
message_files = [
|
||||
MessageFile(
|
||||
message_id=message.id,
|
||||
|
|
@ -826,6 +842,68 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
# Trigger MESSAGE_TRACE for tracing integrations
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
|
||||
"""
|
||||
Extract model provider and model_id from workflow node executions.
|
||||
Returns dict with 'provider' and 'model' keys, or None if not found.
|
||||
"""
|
||||
try:
|
||||
# Query workflow node executions for LLM or Agent nodes
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
|
||||
.order_by(WorkflowNodeExecutionModel.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
node_execution = session.scalar(stmt)
|
||||
|
||||
if not node_execution:
|
||||
return None
|
||||
|
||||
# Try to extract from execution_metadata for agent nodes
|
||||
if node_execution.execution_metadata:
|
||||
try:
|
||||
metadata = json.loads(node_execution.execution_metadata)
|
||||
agent_log = metadata.get("agent_log", [])
|
||||
# Look for the first agent thought with provider info
|
||||
for log_entry in agent_log:
|
||||
entry_metadata = log_entry.get("metadata", {})
|
||||
provider_str = entry_metadata.get("provider")
|
||||
if provider_str:
|
||||
# Parse format like "langgenius/deepseek/deepseek"
|
||||
parts = provider_str.split("/")
|
||||
if len(parts) >= 3:
|
||||
return {"provider": parts[1], "model": parts[2]}
|
||||
elif len(parts) == 2:
|
||||
return {"provider": parts[0], "model": parts[1]}
|
||||
except (json.JSONDecodeError, KeyError, AttributeError) as e:
|
||||
logger.debug("Failed to parse execution_metadata: %s", e)
|
||||
|
||||
# Try to extract from process_data for llm nodes
|
||||
if node_execution.process_data:
|
||||
try:
|
||||
process_data = json.loads(node_execution.process_data)
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider and model:
|
||||
return {"provider": provider, "model": model}
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.debug("Failed to parse process_data: %s", e)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Failed to extract model info from workflow: %s", e)
|
||||
return None
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
|
|
|||
|
|
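The provider string stored in agent logs is a slash-separated plugin path, and `_extract_model_info_from_workflow` above peels it apart positionally. That branch is easy to exercise in isolation; a standalone rendition of just the parsing logic:

    def split_provider(provider_str: str) -> dict[str, str] | None:
        """Mirror of the provider-string branch in _extract_model_info_from_workflow."""
        parts = provider_str.split("/")
        if len(parts) >= 3:
            return {"provider": parts[1], "model": parts[2]}
        if len(parts) == 2:
            return {"provider": parts[0], "model": parts[1]}
        return None


    print(split_provider("langgenius/deepseek/deepseek"))  # {'provider': 'deepseek', 'model': 'deepseek'}
    print(split_provider("openai/gpt-4o"))                 # {'provider': 'openai', 'model': 'gpt-4o'}
    print(split_provider("bare-model"))                    # None -- caller falls through to process_data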
@@ -155,8 +155,17 @@ class BaseAppGenerator:
                             f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
                         )
                 case VariableEntityType.CHECKBOX:
-                    if not isinstance(value, bool):
-                        raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
+                    if isinstance(value, str):
+                        normalized_value = value.strip().lower()
+                        if normalized_value in {"true", "1", "yes", "on"}:
+                            value = True
+                        elif normalized_value in {"false", "0", "no", "off"}:
+                            value = False
+                    elif isinstance(value, (int, float)):
+                        if value == 1:
+                            value = True
+                        elif value == 0:
+                            value = False
                 case _:
                     raise AssertionError("this statement should be unreachable.")
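The CHECKBOX branch above stops rejecting non-bool inputs and instead normalizes common string and numeric spellings; values it does not recognize are left untouched rather than raised on, so downstream validation still sees them. A standalone rendition of the same normalization:

    def normalize_checkbox(value):
        """Standalone mirror of the CHECKBOX normalization above."""
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in {"true", "1", "yes", "on"}:
                return True
            if normalized in {"false", "0", "no", "off"}:
                return False
        elif isinstance(value, (int, float)):
            if value == 1:
                return True
            if value == 0:
                return False
        return value  # unrecognized inputs pass through unchanged


    print(normalize_checkbox(" YES "))  # True
    print(normalize_checkbox(0.0))      # False
    print(normalize_checkbox("maybe"))  # 'maybe' -- left for downstream checks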
@@ -40,6 +40,9 @@ class EasyUITaskState(TaskState):
     """

     llm_result: LLMResult
+    first_token_time: float | None = None
+    last_token_time: float | None = None
+    is_streaming_response: bool = False


 class WorkflowTaskState(TaskState):

@@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
             workflow_run_id=workflow_run_id,
             state_owner_user_id=self._state_owner_user_id,
             state=state.dumps(),
+            pause_reasons=event.reasons,
         )

     def on_graph_end(self, error: Exception | None) -> None:

@@ -332,6 +332,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
             if not self._task_state.llm_result.prompt_messages:
                 self._task_state.llm_result.prompt_messages = chunk.prompt_messages

+            # Track streaming response times
+            if self._task_state.first_token_time is None:
+                self._task_state.first_token_time = time.perf_counter()
+                self._task_state.is_streaming_response = True
+            self._task_state.last_token_time = time.perf_counter()
+
             # handle output moderation chunk
             should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
             if should_direct_answer:

@@ -398,6 +404,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
             message.total_price = usage.total_price
             message.currency = usage.currency
             self._task_state.llm_result.usage.latency = message.provider_response_latency

+            # Add streaming metrics to usage if available
+            if self._task_state.is_streaming_response and self._task_state.first_token_time:
+                start_time = self.start_at
+                first_token_time = self._task_state.first_token_time
+                last_token_time = self._task_state.last_token_time or first_token_time
+                usage.time_to_first_token = round(first_token_time - start_time, 3)
+                usage.time_to_generate = round(last_token_time - first_token_time, 3)
+
+            # Update metadata with the complete usage info
+            self._task_state.metadata.usage = usage

             message.message_metadata = self._task_state.metadata.model_dump_json()

             if trace_manager:
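With the `first_token_time`/`last_token_time` fields recorded above, the two derived metrics in the usage hunk are simple deltas from the pipeline's `start_at`. A worked example with assumed `perf_counter` readings:

    start_at = 100.000          # assumed pipeline start (perf_counter origin is arbitrary)
    first_token_time = 100.420  # first streamed chunk observed
    last_token_time = 103.505   # final streamed chunk observed

    time_to_first_token = round(first_token_time - start_at, 3)      # 0.42
    time_to_generate = round(last_token_time - first_token_time, 3)  # 3.085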
|
|||
|
|
@@ -222,6 +222,59 @@ class TencentSpanBuilder:
links=links,
)

@staticmethod
def build_message_llm_span(
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
) -> SpanData:
"""Build LLM span for message traces with detailed LLM attributes."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)

# Extract model information from `metadata` or `message_data`
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}

model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)

inputs_str = str(trace_info.inputs or "")
outputs_str = str(trace_info.outputs or "")

attributes = {
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: str(model_name),
GEN_AI_PROVIDER: str(model_provider),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
GEN_AI_PROMPT: inputs_str,
GEN_AI_COMPLETION: outputs_str,
INPUT_VALUE: inputs_str,
OUTPUT_VALUE: outputs_str,
}

if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"

return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
)

@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""

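The provider/model lookup above is a fallback chain: prefer the `ls_provider` / `ls_model_name` keys in the trace metadata, then fall back to fields on the message record, then to an empty string. The same pattern in isolation, with hypothetical input data:

```python
trace_metadata = {"conversation_id": "c-1"}  # no ls_* keys present
message_data = {"model_provider": "openai", "model_id": "gpt-4o-mini"}

model_provider = trace_metadata.get("ls_provider") or (
    message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
    message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)
assert (model_provider, model_name) == ("openai", "gpt-4o-mini")
```
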
@@ -107,9 +107,13 @@ class TencentDataTrace(BaseTraceInstance):
links.append(TencentTraceUtils.create_link(trace_info.trace_id))

message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)

self.trace_client.add_span(message_span)

# Add LLM child span with detailed attributes
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
self.trace_client.add_span(llm_span)

self._record_message_llm_metrics(trace_info)

# Record trace duration for entry span

@@ -1,20 +1,110 @@
import re
from operator import itemgetter
from typing import cast


class JiebaKeywordTableHandler:
def __init__(self):
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

tfidf = self._load_tfidf_extractor()
tfidf.stop_words = STOPWORDS  # type: ignore[attr-defined]
self._tfidf = tfidf

def _load_tfidf_extractor(self):
"""
Load jieba TFIDF extractor with fallback strategy.

Loading Flow:
┌──────────────────────────────────────┐
│ jieba.analyse.default_tfidf exists?  │
└──────────────────────────────────────┘
      │                  │
     YES                 NO
      │                  │
      ▼                  ▼
┌──────────────────┐  ┌──────────────────────────────────┐
│ Return default   │  │ jieba.analyse.TFIDF exists?      │
│ TFIDF            │  └──────────────────────────────────┘
└──────────────────┘       │              │
                          YES             NO
                           │              │
                           │              ▼
                           │  ┌────────────────────────────┐
                           │  │ Try import from            │
                           │  │ jieba.analyse.tfidf.TFIDF  │
                           │  └────────────────────────────┘
                           │       │            │
                           │    SUCCESS       FAILED
                           │       │            │
                           ▼       ▼            ▼
               ┌────────────────────────┐  ┌─────────────────┐
               │ Instantiate TFIDF()    │  │ Build fallback  │
               │ & cache to default     │  │ _SimpleTFIDF    │
               └────────────────────────┘  └─────────────────┘
"""
import jieba.analyse  # type: ignore

tfidf = getattr(jieba.analyse, "default_tfidf", None)
if tfidf is not None:
return tfidf

tfidf_class = getattr(jieba.analyse, "TFIDF", None)
if tfidf_class is None:
try:
from jieba.analyse.tfidf import TFIDF  # type: ignore

tfidf_class = TFIDF
except Exception:
tfidf_class = None

if tfidf_class is not None:
tfidf = tfidf_class()
jieba.analyse.default_tfidf = tfidf  # type: ignore[attr-defined]
return tfidf

return self._build_fallback_tfidf()

@staticmethod
def _build_fallback_tfidf():
"""Fallback lightweight TFIDF for environments missing jieba's TFIDF."""
import jieba  # type: ignore

from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

jieba.analyse.default_tfidf.stop_words = STOPWORDS  # type: ignore
class _SimpleTFIDF:
def __init__(self):
self.stop_words = STOPWORDS
self._lcut = getattr(jieba, "lcut", None)

def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
top_k = kwargs.pop("topK", top_k)
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
tokens = list(cut(sentence))
else:
tokens = re.findall(r"\w+", sentence)

words = [w for w in tokens if w and w not in self.stop_words]
freq: dict[str, int] = {}
for w in words:
freq[w] = freq.get(w, 0) + 1

sorted_words = sorted(freq.items(), key=itemgetter(1), reverse=True)
if top_k is not None:
sorted_words = sorted_words[:top_k]

return [item[0] for item in sorted_words]

return _SimpleTFIDF()

def extract_keywords(self, text: str, max_keywords_per_chunk: int | None = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
import jieba.analyse  # type: ignore

keywords = jieba.analyse.extract_tags(
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)

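Whichever branch `_load_tfidf_extractor` takes, the handler ends up holding an object that answers `extract_tags(sentence=..., topK=...)`, so `extract_keywords` behaves the same against jieba's real TF-IDF or the frequency-based `_SimpleTFIDF`. A usage sketch; it assumes the `jieba` package is installed, and the exact keywords depend on its dictionaries:

```python
handler = JiebaKeywordTableHandler()
keywords = handler.extract_keywords(
    "Dify is an open-source platform for building LLM applications",
    max_keywords_per_chunk=5,
)
print(keywords)  # up to five salient tokens, stopwords removed
```
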
@@ -302,8 +302,7 @@ class OracleVector(BaseVector):
nltk.data.find("tokenizers/punkt")
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download("punkt")
nltk.download("stopwords")
raise LookupError("Unable to find the required NLTK data package: punkt and stopwords")
e_str = re.sub(r"[^\w ]", "", query)
all_tokens = nltk.word_tokenize(e_str)
stop_words = stopwords.words("english")

@@ -167,13 +167,18 @@ class WeaviateVector(BaseVector):

try:
if not self._client.collections.exists(self._collection_name):
tokenization = (
wc.Tokenization(dify_config.WEAVIATE_TOKENIZATION)
if dify_config.WEAVIATE_TOKENIZATION
else wc.Tokenization.WORD
)
self._client.collections.create(
name=self._collection_name,
properties=[
wc.Property(
name=Field.TEXT_KEY.value,
data_type=wc.DataType.TEXT,
tokenization=wc.Tokenization.WORD,
tokenization=tokenization,
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),

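This hunk swaps a hard-coded `WORD` tokenizer for a configurable one, falling back to `WORD` when `WEAVIATE_TOKENIZATION` is unset. The selection is just a guarded enum construction; a reduced sketch with a stand-in enum rather than the real `weaviate.classes.config.Tokenization`:

```python
from enum import Enum


class Tokenization(Enum):  # stand-in for weaviate's Tokenization enum
    WORD = "word"
    FIELD = "field"


def pick_tokenization(configured: str | None) -> Tokenization:
    # Unset config falls back to WORD; an unknown value raises ValueError early.
    return Tokenization(configured) if configured else Tokenization.WORD


assert pick_tokenization(None) is Tokenization.WORD
assert pick_tokenization("field") is Tokenization.FIELD
```
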
@@ -141,6 +141,7 @@ class WorkflowToolProviderController(ToolProviderController):
form=parameter.form,
llm_description=parameter.description,
required=variable.required,
default=variable.default,
options=options,
placeholder=I18nObject(en_US="", zh_Hans=""),
)

@@ -71,6 +71,11 @@ class TriggerProviderIdentity(BaseModel):
icon_dark: str | None = Field(default=None, description="The dark icon of the trigger provider")
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")

@field_validator("tags", mode="before")
@classmethod
def validate_tags(cls, v: list[str] | None) -> list[str]:
return v or []


class EventIdentity(BaseModel):
"""

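Because the validator runs with `mode="before"`, it sees the raw input, so an explicit `tags: null` in a provider manifest is coerced to `[]` before list validation would otherwise fail. A self-contained sketch of the same idea:

```python
from pydantic import BaseModel, Field, field_validator


class Identity(BaseModel):
    tags: list[str] = Field(default_factory=list)

    @field_validator("tags", mode="before")
    @classmethod
    def validate_tags(cls, v: list[str] | None) -> list[str]:
        return v or []


assert Identity(tags=None).tags == []  # rejected without the before-validator
assert Identity().tags == []
```
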
@@ -1,17 +1,11 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity

__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowPauseEntity",
]

@@ -1,49 +1,26 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias
from typing import Annotated, Literal, TypeAlias

from pydantic import BaseModel, Discriminator, Tag
from pydantic import BaseModel, Field


class _PauseReasonType(StrEnum):
class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()


class _PauseReasonBase(BaseModel):
TYPE: ClassVar[_PauseReasonType]
class HumanInputRequired(BaseModel):
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED

form_id: str
# The identifier of the human input node causing the pause.
node_id: str


class HumanInputRequired(_PauseReasonBase):
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED


class SchedulingPause(_PauseReasonBase):
TYPE = _PauseReasonType.SCHEDULED_PAUSE
class SchedulingPause(BaseModel):
TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE

message: str


def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
if isinstance(v, _PauseReasonBase):
return v.TYPE
elif isinstance(v, dict):
reason_type_str = v.get("TYPE")
if reason_type_str is None:
return None
try:
reason_type = _PauseReasonType(reason_type_str)
except ValueError:
return None
return reason_type
else:
# return None if the discriminator value isn't found
return None


PauseReason: TypeAlias = Annotated[
(
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
),
Discriminator(_get_pause_reason_discriminator),
]
PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]

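Once each variant carries a `Literal`-typed `TYPE` field, pydantic's built-in `Field(discriminator="TYPE")` replaces the hand-written `Discriminator`/`Tag` callable. A round-trip sketch of that scheme, mirroring the new definitions above:

```python
from enum import StrEnum, auto
from typing import Annotated, Literal, TypeAlias

from pydantic import BaseModel, Field, TypeAdapter


class PauseReasonType(StrEnum):
    HUMAN_INPUT_REQUIRED = auto()
    SCHEDULED_PAUSE = auto()


class HumanInputRequired(BaseModel):
    TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
    form_id: str
    node_id: str


class SchedulingPause(BaseModel):
    TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
    message: str


PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]

adapter = TypeAdapter(PauseReason)
reason = adapter.validate_python({"TYPE": "human_input_required", "form_id": "f1", "node_id": "n1"})
assert isinstance(reason, HumanInputRequired)
```

`StrEnum` with `auto()` lowercases member names, so the serialized discriminator is the plain string `"human_input_required"`.
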
@@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reason: PauseReason | None = Field(default=None)
pause_reasons: list[PauseReason] = Field(default_factory=list)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])

@@ -107,7 +107,7 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reason: PauseReason | None = None
pause_reasons: list[PauseReason] = field(default_factory=list)
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0

@@ -137,10 +137,8 @@ class GraphExecution:
raise RuntimeError("Cannot pause execution that has completed")
if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted")
if self.paused:
return
self.paused = True
self.pause_reason = reason
self.pause_reasons.append(reason)

def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""

@@ -195,7 +193,7 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
pause_reason=self.pause_reason,
pause_reasons=self.pause_reasons,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,

@@ -221,7 +219,7 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
self.pause_reason = state.pause_reason
self.pause_reasons = state.pause_reasons
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {

@@ -110,7 +110,13 @@ class EventManager:
"""
with self._lock.write_lock():
self._events.append(event)
self._notify_layers(event)

# NOTE: `_notify_layers` is intentionally called outside the critical section
# to minimize lock contention and avoid blocking other readers or writers.
#
# The public `notify_layers` method also does not use a write lock,
# so protecting `_notify_layers` with a lock here is unnecessary.
self._notify_layers(event)

def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""

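The comment codifies a standard lock-granularity pattern: mutate shared state inside the critical section, then fan callbacks out after releasing it so a slow layer cannot block writers. A distilled sketch with `threading.Lock` standing in for the read-write lock used here:

```python
import threading
from collections.abc import Callable


class EventBuffer:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._events: list[object] = []
        self._listeners: list[Callable[[object], None]] = []

    def collect(self, event: object) -> None:
        with self._lock:  # critical section covers only the shared append
            self._events.append(event)
        for listener in self._listeners:  # callbacks run without holding the lock
            listener(event)
```
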
@@ -232,7 +232,7 @@ class GraphEngine:
self._graph_execution.start()
else:
self._graph_execution.paused = False
self._graph_execution.pause_reason = None
self._graph_execution.pause_reasons = []

start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event)

@@ -246,11 +246,11 @@ class GraphEngine:

# Handle completion
if self._graph_execution.is_paused:
pause_reason = self._graph_execution.pause_reason
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
pause_reasons = self._graph_execution.pause_reasons
assert pause_reasons, "pause_reasons should not be empty when execution is paused."
# Ensure we have a valid PauseReason for the event
paused_event = GraphRunPausedEvent(
reason=pause_reason,
reasons=pause_reasons,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)

@@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""

# reason: str | None = Field(default=None, description="reason for pause")
reason: PauseReason = Field(..., description="reason for pause")
reasons: list[PauseReason] = Field(description="reasons for pause", default_factory=list)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",

@@ -65,7 +65,8 @@ class HumanInputNode(Node):
return self._pause_generator()

def _pause_generator(self):
yield PauseRequestedEvent(reason=HumanInputRequired())
# TODO(QuantumGhost): yield a real form id.
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))

def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""

@@ -229,6 +229,8 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
return lambda x: x.transfer_method
case "url":
return lambda x: x.remote_url or ""
case "related_id":
return lambda x: x.related_id or ""
case _:
raise InvalidKeyError(f"Invalid key: {key}")

@@ -299,7 +301,7 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla

def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
if key in {"type", "transfer_method"}:

@@ -358,7 +360,7 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:

def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
elif order_by == "size":

@@ -329,7 +329,15 @@ class ToolNode(Node):
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"

# Check if this LINK message is a file link
file_obj = (message.meta or {}).get("file")
if isinstance(file_obj, File):
files.append(file_obj)
stream_text = f"File: {message.message.text}\n"
else:
stream_text = f"Link: {message.message.text}\n"

text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],

@@ -10,6 +10,7 @@ from typing import Any, Protocol
from pydantic.json import pydantic_encoder

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool

@@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol):


class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate."""
"""Structural interface for graph execution aggregate.

Defines the minimal set of attributes and methods required from a GraphExecution entity
for runtime orchestration and state management.
"""

workflow_id: str
started: bool

@@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]

def start(self) -> None:
"""Transition execution into the running state."""

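Adding `pause_reasons: list[PauseReason]` to the Protocol widens the structural contract: any object used where a `GraphExecutionProtocol` is expected must now expose that attribute, with no inheritance required. A minimal sketch of how structural checking plays out, e.g. for test doubles:

```python
from typing import Protocol


class ExecutionLike(Protocol):
    workflow_id: str
    pause_reasons: list[str]  # list[PauseReason] in the real interface


class FakeExecution:  # no inheritance from the Protocol
    def __init__(self) -> None:
        self.workflow_id = "wf-1"
        self.pause_reasons: list[str] = []


def count_pauses(execution: ExecutionLike) -> int:
    return len(execution.pause_reasons)


assert count_pauses(FakeExecution()) == 0
```
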
@@ -112,7 +112,7 @@ class Storage:
def exists(self, filename):
return self.storage_runner.exists(filename)

def delete(self, filename):
def delete(self, filename: str):
return self.storage_runner.delete(filename)

def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:

@@ -0,0 +1,41 @@
"""Add workflow_pause_reasons table

Revision ID: 7bb281b7a422
Revises: 09cfdda155d1
Create Date: 2025-11-18 18:59:26.999572

"""

from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "7bb281b7a422"
down_revision = "09cfdda155d1"
branch_labels = None
depends_on = None


def upgrade():
op.create_table(
"workflow_pause_reasons",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),

sa.Column("pause_id", models.types.StringUUID(), nullable=False),
sa.Column("type_", sa.String(20), nullable=False),
sa.Column("form_id", sa.String(length=36), nullable=False),
sa.Column("node_id", sa.String(length=255), nullable=False),
sa.Column("message", sa.String(length=255), nullable=False),

sa.PrimaryKeyConstraint("id", name=op.f("workflow_pause_reasons_pkey")),
)
with op.batch_alter_table("workflow_pause_reasons", schema=None) as batch_op:
batch_op.create_index(batch_op.f("workflow_pause_reasons_pause_id_idx"), ["pause_id"], unique=False)


def downgrade():
op.drop_table("workflow_pause_reasons")

@@ -88,7 +88,9 @@ class Account(UserMixin, TypeBase):
__tablename__ = "accounts"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[str | None] = mapped_column(String(255), default=None)

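The same `id` rewrite repeats across every model below. With SQLAlchemy's `MappedAsDataclass` mapping, a single `default=` does double duty as the Core insert default and the dataclass default; splitting it into `insert_default=` (Core, applied on INSERT) plus `default_factory=` (dataclass, applied at construction) makes the two layers explicit. A reduced sketch of the pattern; the model and table names are illustrative:

```python
import uuid

from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class Base(MappedAsDataclass, DeclarativeBase):
    pass


class Example(Base):  # illustrative, not part of the diff
    __tablename__ = "examples"

    id: Mapped[str] = mapped_column(
        primary_key=True,
        insert_default=lambda: str(uuid.uuid4()),   # Core-level default on INSERT
        default_factory=lambda: str(uuid.uuid4()),  # dataclass-level default at construction
        init=False,                                 # keep id out of the generated __init__
    )
    name: Mapped[str] = mapped_column(default="")


row = Example()
assert isinstance(row.id, str)  # populated by default_factory before any flush
```
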
@@ -235,7 +237,9 @@ class Tenant(TypeBase):
__tablename__ = "tenants"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic")

@@ -275,7 +279,9 @@ class TenantAccountJoin(TypeBase):
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)

@@ -297,7 +303,9 @@ class AccountIntegrate(TypeBase):
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))

@@ -348,7 +356,9 @@ class TenantPluginPermission(TypeBase):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(
String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE

@@ -375,7 +385,9 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(
String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY

@@ -24,7 +24,9 @@ class APIBasedExtension(TypeBase):
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -920,7 +920,12 @@ class AppDatasetJoin(TypeBase):
)

id: Mapped[str] = mapped_column(
StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
StringUUID,
primary_key=True,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -941,7 +946,12 @@ class DatasetQuery(TypeBase):
)

id: Mapped[str] = mapped_column(
StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
StringUUID,
primary_key=True,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)

@@ -961,7 +971,13 @@ class DatasetKeywordTable(TypeBase):
sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
)

id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False, unique=True)
keyword_table: Mapped[str] = mapped_column(LongText, nullable=False)
data_source_type: Mapped[str] = mapped_column(

@@ -1012,7 +1028,13 @@ class Embedding(TypeBase):
sa.Index("created_at_idx", "created_at"),
)

id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
model_name: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'")
)

@@ -1037,7 +1059,13 @@ class DatasetCollectionBinding(TypeBase):
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
)

id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)

@@ -1073,7 +1101,13 @@ class Whitelist(TypeBase):
sa.PrimaryKeyConstraint("id", name="whitelists_pkey"),
sa.Index("whitelists_tenant_idx", "tenant_id"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
category: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(

@@ -1090,7 +1124,13 @@ class DatasetPermission(TypeBase):
sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), primary_key=True, init=False)
id: Mapped[str] = mapped_column(
StringUUID,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
primary_key=True,
init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1110,7 +1150,13 @@ class ExternalKnowledgeApis(TypeBase):
sa.Index("external_knowledge_apis_name_idx", "name"),
)

id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1167,7 +1213,13 @@ class ExternalKnowledgeBindings(TypeBase):
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
)

id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
external_knowledge_api_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1191,7 +1243,9 @@ class DatasetAutoDisableLog(TypeBase):
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1209,7 +1263,9 @@ class RateLimitLog(TypeBase):
sa.Index("rate_limit_log_operation_idx", "operation"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False)
operation: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -1226,7 +1282,9 @@ class DatasetMetadata(TypeBase):
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -1255,7 +1313,9 @@ class DatasetMetadataBinding(TypeBase):
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
metadata_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1270,7 +1330,9 @@ class PipelineBuiltInTemplate(TypeBase):
__tablename__ = "pipeline_built_in_templates"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)

@@ -1300,7 +1362,9 @@ class PipelineCustomizedTemplate(TypeBase):
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)

@@ -1335,7 +1399,9 @@ class Pipeline(TypeBase):
__tablename__ = "pipelines"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("''"))

@@ -1368,7 +1434,9 @@ class DocumentPipelineExecutionLog(TypeBase):
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
pipeline_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)

@@ -1385,7 +1453,9 @@ class PipelineRecommendedPlugin(TypeBase):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)

@@ -572,7 +572,9 @@ class InstalledApp(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_owner_tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -606,7 +608,9 @@ class OAuthProviderApp(TypeBase):
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -1251,9 +1255,13 @@ class Message(Base):
"id": self.id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"model_provider": self.model_provider,
"model_id": self.model_id,
"inputs": self.inputs,
"query": self.query,
"message_tokens": self.message_tokens,
"answer_tokens": self.answer_tokens,
"provider_response_latency": self.provider_response_latency,
"total_price": self.total_price,
"message": self.message,
"answer": self.answer,

@@ -1275,8 +1283,12 @@ class Message(Base):
id=data["id"],
app_id=data["app_id"],
conversation_id=data["conversation_id"],
model_provider=data.get("model_provider"),
model_id=data["model_id"],
inputs=data["inputs"],
message_tokens=data.get("message_tokens", 0),
answer_tokens=data.get("answer_tokens", 0),
provider_response_latency=data.get("provider_response_latency", 0.0),
total_price=data["total_price"],
query=data["query"],
message=data["message"],

@@ -1303,7 +1315,9 @@ class MessageFeedback(TypeBase):
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1352,7 +1366,9 @@ class MessageFile(TypeBase):
sa.Index("message_file_created_by_idx", "created_by"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)

@@ -1444,7 +1460,9 @@ class AppAnnotationSetting(TypeBase):
sa.Index("app_annotation_settings_app_idx", "app_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
score_threshold: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
collection_binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1480,7 +1498,9 @@ class OperationLog(TypeBase):
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
action: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -1546,7 +1566,9 @@ class AppMCPServer(TypeBase):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -1756,7 +1778,9 @@ class ApiRequest(TypeBase):
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
api_token_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
path: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -1775,7 +1799,9 @@ class MessageChain(TypeBase):
sa.Index("message_chain_message_id_idx", "message_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
input: Mapped[str | None] = mapped_column(LongText, nullable=True)

@@ -1906,7 +1932,9 @@ class DatasetRetrieverResource(TypeBase):
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1938,7 +1966,9 @@ class Tag(TypeBase):

TAG_TYPE_LIST = ["knowledge", "app"]

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -1956,7 +1986,9 @@ class TagBinding(TypeBase):
sa.Index("tag_bind_tag_id_idx", "tag_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
tag_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
target_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

@@ -1973,7 +2005,9 @@ class TraceAppConfig(TypeBase):
sa.Index("trace_app_config_app_id_idx", "app_id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True)

@@ -17,7 +17,9 @@ class DatasourceOauthParamConfig(TypeBase):
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)

@@ -30,7 +32,9 @@ class DatasourceProvider(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)

@@ -60,7 +64,9 @@ class DatasourceOauthTenantParamConfig(TypeBase):
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)

@@ -58,7 +58,13 @@ class Provider(TypeBase):
),
)

id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuidv7()),
default_factory=lambda: str(uuidv7()),
init=False,
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(

@@ -132,7 +138,9 @@ class ProviderModel(TypeBase):
),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -173,7 +181,9 @@ class TenantDefaultModel(TypeBase):
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -193,7 +203,9 @@ class TenantPreferredModelProvider(TypeBase):
sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)

@@ -212,7 +224,9 @@ class ProviderOrder(TypeBase):
sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -245,7 +259,9 @@ class ProviderModelSetting(TypeBase):
sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -273,7 +289,9 @@ class LoadBalancingModelConfig(TypeBase):
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -302,7 +320,9 @@ class ProviderCredential(TypeBase):
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -332,7 +352,9 @@ class ProviderModelCredential(TypeBase):
),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -18,7 +18,9 @@ class DataSourceOauthBinding(TypeBase):
adjusted_json_index("source_info_idx", "source_info"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -44,7 +46,9 @@ class DataSourceApiKeyAuthBinding(TypeBase):
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
category: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -24,7 +24,8 @@ class CeleryTask(TypeBase):
     result: Mapped[bytes | None] = mapped_column(BinaryData, nullable=True, default=None)
     date_done: Mapped[datetime | None] = mapped_column(
         DateTime,
-        default=naive_utc_now,
+        insert_default=naive_utc_now,
+        default=None,
         onupdate=naive_utc_now,
         nullable=True,
     )

@@ -47,4 +48,6 @@ class CeleryTaskSet(TypeBase):
     )
     taskset_id: Mapped[str] = mapped_column(String(155), unique=True)
     result: Mapped[bytes | None] = mapped_column(BinaryData, nullable=True, default=None)
-    date_done: Mapped[datetime | None] = mapped_column(DateTime, default=naive_utc_now, nullable=True)
+    date_done: Mapped[datetime | None] = mapped_column(
+        DateTime, insert_default=naive_utc_now, default=None, nullable=True
+    )

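The `date_done` change is the same migration applied to a timestamp: the old `default=naive_utc_now` captured a time when the Python object was created, whereas `insert_default=naive_utc_now` defers stamping to the INSERT and `default=None` keeps the in-memory field empty until then. A hedged sketch of the resulting behavior (session and surrounding wiring assumed):

    task_set = CeleryTaskSet(taskset_id="ts-1")
    print(task_set.date_done)  # None: nothing is stamped at construction time
    session.add(task_set)
    session.flush()            # insert_default=naive_utc_now fires during the INSERT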
@@ -30,7 +30,9 @@ class ToolOAuthSystemClient(TypeBase):
         sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
     provider: Mapped[str] = mapped_column(String(255), nullable=False)
     # oauth params of the tool provider

@@ -45,7 +47,9 @@ class ToolOAuthTenantClient(TypeBase):
         sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # tenant id
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -71,7 +75,9 @@ class BuiltinToolProvider(TypeBase):
     )

     # id of the tool provider
-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     name: Mapped[str] = mapped_column(
         String(256),
         nullable=False,

@@ -120,7 +126,9 @@ class ApiToolProvider(TypeBase):
         sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # name of the api provider
     name: Mapped[str] = mapped_column(
         String(255),

@@ -192,7 +200,9 @@ class ToolLabelBinding(TypeBase):
         sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # tool id
     tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
     # tool type

@@ -213,7 +223,9 @@ class WorkflowToolProvider(TypeBase):
         sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # name of the workflow provider
     name: Mapped[str] = mapped_column(String(255), nullable=False)
     # label of the workflow provider

@@ -279,7 +291,9 @@ class MCPToolProvider(TypeBase):
         sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # name of the mcp provider
     name: Mapped[str] = mapped_column(String(40), nullable=False)
     # server identifier of the mcp provider

@@ -360,7 +374,9 @@ class ToolModelInvoke(TypeBase):
     __tablename__ = "tool_model_invokes"
     __table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # who invoke this tool
     user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     # tenant id

@@ -413,7 +429,9 @@ class ToolConversationVariables(TypeBase):
         sa.Index("conversation_id_idx", "conversation_id"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # conversation user id
     user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     # tenant id

@@ -450,7 +468,9 @@ class ToolFile(TypeBase):
         sa.Index("tool_file_conversation_id_idx", "conversation_id"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # conversation user id
     user_id: Mapped[str] = mapped_column(StringUUID)
     # tenant id

@@ -481,7 +501,9 @@ class DeprecatedPublishedAppTool(TypeBase):
         sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # id of the app
     app_id: Mapped[str] = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)

@@ -41,7 +41,9 @@ class TriggerSubscription(TypeBase):
         UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription instance name")
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -111,7 +113,9 @@ class TriggerOAuthSystemClient(TypeBase):
         sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
     provider: Mapped[str] = mapped_column(String(255), nullable=False)
     # oauth params of the trigger provider

@@ -136,7 +140,9 @@ class TriggerOAuthTenantClient(TypeBase):
         sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     # tenant id
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)

@@ -202,7 +208,9 @@ class WorkflowTriggerLog(TypeBase):
         sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+    )
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -294,7 +302,9 @@ class WorkflowWebhookTrigger(TypeBase):
         sa.UniqueConstraint("webhook_id", name="uniq_webhook_id"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+    )
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     node_id: Mapped[str] = mapped_column(String(64), nullable=False)
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -351,7 +361,9 @@ class WorkflowPluginTrigger(TypeBase):
         sa.UniqueConstraint("app_id", "node_id", name="uniq_app_node_subscription"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     node_id: Mapped[str] = mapped_column(String(64), nullable=False)
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -395,7 +407,9 @@ class AppTrigger(TypeBase):
         sa.Index("app_trigger_tenant_app_idx", "tenant_id", "app_id"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+    )
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     node_id: Mapped[str | None] = mapped_column(String(64), nullable=False)

@@ -443,7 +457,13 @@ class WorkflowSchedulePlan(TypeBase):
         sa.Index("workflow_schedule_plan_next_idx", "next_run_at"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuidv7()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID,
+        primary_key=True,
+        insert_default=lambda: str(uuidv7()),
+        default_factory=lambda: str(uuidv7()),
+        init=False,
+    )
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     node_id: Mapped[str] = mapped_column(String(64), nullable=False)
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -18,7 +18,9 @@ class SavedMessage(TypeBase):
         sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'"))

@@ -42,7 +44,9 @@ class PinnedConversation(TypeBase):
         sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     conversation_id: Mapped[str] = mapped_column(StringUUID)
     created_by_role: Mapped[str] = mapped_column(

@@ -29,6 +29,7 @@ from core.workflow.constants import (
     CONVERSATION_VARIABLE_NODE_ID,
     SYSTEM_VARIABLE_NODE_ID,
 )
+from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
 from core.workflow.enums import NodeType
 from extensions.ext_storage import Storage
 from factories.variable_factory import TypeMismatchError, build_segment_with_type

@@ -1102,7 +1103,9 @@ class WorkflowAppLog(TypeBase):
         sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+    )
     tenant_id: Mapped[str] = mapped_column(StringUUID)
     app_id: Mapped[str] = mapped_column(StringUUID)
     workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -1728,3 +1731,68 @@ class WorkflowPause(DefaultFieldsMixin, Base):
         primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
         back_populates="pause",
     )
+
+
+class WorkflowPauseReason(DefaultFieldsMixin, Base):
+    __tablename__ = "workflow_pause_reasons"
+
+    # `pause_id` represents the identifier of the pause,
+    # corresponding to the `id` field of `WorkflowPause`.
+    pause_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
+
+    type_: Mapped[PauseReasonType] = mapped_column(EnumText(PauseReasonType), nullable=False)
+
+    # form_id is non-empty if and only if type_ == PauseReasonType.HUMAN_INPUT_REQUIRED
+    #
+    form_id: Mapped[str] = mapped_column(
+        String(36),
+        nullable=False,
+        default="",
+    )
+
+    # message records the text description of this pause reason. For example,
+    # "The workflow has been paused due to scheduling."
+    #
+    # An empty message means that no reason text was specified.
+    message: Mapped[str] = mapped_column(
+        String(255),
+        nullable=False,
+        default="",
+    )
+
+    # `node_id` is the identifier of the node causing the pause, corresponding to
+    # `Node.id`. An empty `node_id` means that this pause reason is not caused by any specific node
+    # (e.g. time-slicing pauses).
+    node_id: Mapped[str] = mapped_column(
+        String(255),
+        nullable=False,
+        default="",
+    )
+
+    # Relationship to WorkflowPause
+    pause: Mapped[WorkflowPause] = orm.relationship(
+        foreign_keys=[pause_id],
+        # require explicit preloading.
+        lazy="raise",
+        uselist=False,
+        primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id",
+    )
+
+    @classmethod
+    def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
+        if isinstance(pause_reason, HumanInputRequired):
+            return cls(
+                type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
+            )
+        elif isinstance(pause_reason, SchedulingPause):
+            return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="")
+        else:
+            raise AssertionError(f"Unknown pause reason type: {pause_reason}")
+
+    def to_entity(self) -> PauseReason:
+        if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
+            return HumanInputRequired(form_id=self.form_id, node_id=self.node_id)
+        elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
+            return SchedulingPause(message=self.message)
+        else:
+            raise AssertionError(f"Unknown pause reason type: {self.type_}")

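The new `WorkflowPauseReason` model is a persistence mirror of the in-memory `PauseReason` entities, with `from_entity`/`to_entity` intended to be symmetric. A hedged round-trip sketch (`pause` is a placeholder `WorkflowPause` row; the final assertion assumes value equality on the entity type):

    from core.workflow.entities.pause_reason import SchedulingPause

    reason = SchedulingPause(message="The workflow has been paused due to scheduling.")
    row = WorkflowPauseReason.from_entity(reason)  # ORM row, not yet attached to a pause
    row.pause_id = pause.id                        # link it to its WorkflowPause
    assert row.to_entity() == reason               # symmetric mapping back to the entity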
@@ -1,6 +1,6 @@
 [project]
 name = "dify-api"
-version = "1.10.0"
+version = "1.10.1"
 requires-python = ">=3.11,<3.13"

 dependencies = [

@@ -38,11 +38,12 @@ from collections.abc import Sequence
 from datetime import datetime
 from typing import Protocol

-from core.workflow.entities.workflow_pause import WorkflowPauseEntity
+from core.workflow.entities.pause_reason import PauseReason
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.enums import WorkflowRunTriggeredFrom
 from models.workflow import WorkflowRun
+from repositories.entities.workflow_pause import WorkflowPauseEntity
 from repositories.types import (
     AverageInteractionStats,
     DailyRunsStats,

@@ -257,6 +258,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
         workflow_run_id: str,
         state_owner_user_id: str,
         state: str,
+        pause_reasons: Sequence[PauseReason],
     ) -> WorkflowPauseEntity:
         """
         Create a new workflow pause state.

@@ -7,8 +7,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
 """

 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from datetime import datetime

+from core.workflow.entities.pause_reason import PauseReason
+

 class WorkflowPauseEntity(ABC):
     """

@@ -59,3 +62,15 @@ class WorkflowPauseEntity(ABC):
         the pause is not resumed yet.
         """
         pass
+
+    @abstractmethod
+    def get_pause_reasons(self) -> Sequence[PauseReason]:
+        """
+        Retrieve detailed reasons for this pause.
+
+        Returns a sequence of `PauseReason` objects describing the specific nodes and
+        reasons for which the workflow execution was paused.
+        This information is related to, but distinct from, the `PauseReason` type
+        defined in `api/core/workflow/entities/pause_reason.py`.
+        """
+        ...

@@ -31,7 +31,7 @@ from sqlalchemy import and_, delete, func, null, or_, select
 from sqlalchemy.engine import CursorResult
 from sqlalchemy.orm import Session, selectinload, sessionmaker

-from core.workflow.entities.workflow_pause import WorkflowPauseEntity
+from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
 from core.workflow.enums import WorkflowExecutionStatus
 from extensions.ext_storage import storage
 from libs.datetime_utils import naive_utc_now

@@ -41,8 +41,9 @@ from libs.time_parser import get_time_threshold
 from libs.uuid_utils import uuidv7
 from models.enums import WorkflowRunTriggeredFrom
 from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowRun
+from models.workflow import WorkflowPauseReason, WorkflowRun
 from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.entities.workflow_pause import WorkflowPauseEntity
 from repositories.types import (
     AverageInteractionStats,
     DailyRunsStats,

@@ -318,6 +319,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
         workflow_run_id: str,
         state_owner_user_id: str,
         state: str,
+        pause_reasons: Sequence[PauseReason],
     ) -> WorkflowPauseEntity:
         """
         Create a new workflow pause state.

@@ -371,6 +373,25 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
         pause_model.workflow_run_id = workflow_run.id
         pause_model.state_object_key = state_obj_key
         pause_model.created_at = naive_utc_now()
+        pause_reason_models = []
+        for reason in pause_reasons:
+            if isinstance(reason, HumanInputRequired):
+                # TODO(QuantumGhost): record node_id for `WorkflowPauseReason`
+                pause_reason_model = WorkflowPauseReason(
+                    pause_id=pause_model.id,
+                    type_=reason.TYPE,
+                    form_id=reason.form_id,
+                )
+            elif isinstance(reason, SchedulingPause):
+                pause_reason_model = WorkflowPauseReason(
+                    pause_id=pause_model.id,
+                    type_=reason.TYPE,
+                    message=reason.message,
+                )
+            else:
+                raise AssertionError(f"unknown reason type: {type(reason)}")
+
+            pause_reason_models.append(pause_reason_model)

         # Update workflow run status
         workflow_run.status = WorkflowExecutionStatus.PAUSED

@@ -378,10 +399,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
         # Save everything in a transaction
         session.add(pause_model)
         session.add(workflow_run)
+        session.add_all(pause_reason_models)

         logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)

-        return _PrivateWorkflowPauseEntity.from_models(pause_model)
+        return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
+
+    def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
+        reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
+        pause_reason_models = session.scalars(reason_stmt).all()
+        return pause_reason_models

     def get_workflow_pause(
         self,

@@ -413,8 +440,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
         pause_model = workflow_run.pause
         if pause_model is None:
             return None
+        pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)

-        return _PrivateWorkflowPauseEntity.from_models(pause_model)
+        human_input_form: list[Any] = []
+        # TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
+
+        return _PrivateWorkflowPauseEntity(
+            pause_model=pause_model,
+            reason_models=pause_reason_models,
+            human_input_form=human_input_form,
+        )

     def resume_workflow_pause(
         self,

@@ -466,6 +501,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
         if pause_model.resumed_at is not None:
             raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")

+        pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
+
         # Mark as resumed
         pause_model.resumed_at = naive_utc_now()
         workflow_run.pause_id = None  # type: ignore

@@ -476,7 +513,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):

         logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)

-        return _PrivateWorkflowPauseEntity.from_models(pause_model)
+        return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)

     def delete_workflow_pause(
         self,

@@ -815,26 +852,13 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
         self,
         *,
         pause_model: WorkflowPauseModel,
+        reason_models: Sequence[WorkflowPauseReason],
+        human_input_form: Sequence = (),
     ) -> None:
         self._pause_model = pause_model
+        self._reason_models = reason_models
         self._cached_state: bytes | None = None
-
-    @classmethod
-    def from_models(cls, workflow_pause_model) -> "_PrivateWorkflowPauseEntity":
-        """
-        Create a _PrivateWorkflowPauseEntity from database models.
-
-        Args:
-            workflow_pause_model: The WorkflowPause database model
-            upload_file_model: The UploadFile database model
-
-        Returns:
-            _PrivateWorkflowPauseEntity: The constructed entity
-
-        Raises:
-            ValueError: If required model attributes are missing
-        """
-        return cls(pause_model=workflow_pause_model)
+        self._human_input_form = human_input_form

     @property
     def id(self) -> str:

@@ -867,3 +891,6 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
     @property
     def resumed_at(self) -> datetime | None:
         return self._pause_model.resumed_at
+
+    def get_pause_reasons(self) -> Sequence[PauseReason]:
+        return [reason.to_entity() for reason in self._reason_models]

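With the repository change above, callers now pass the pause reasons alongside the serialized state, and the returned entity can surface them again. A hedged usage sketch (the repository instance, `run_id`, `user_id`, and `serialized_state` are placeholders, not the project's exact wiring):

    from core.workflow.entities.pause_reason import SchedulingPause

    pause = repo.create_workflow_pause(
        workflow_run_id=run_id,
        state_owner_user_id=user_id,
        state=serialized_state,  # serialized runtime snapshot of the paused run
        pause_reasons=[SchedulingPause(message="time slice exhausted")],
    )
    print(pause.get_pause_reasons())  # entities rebuilt from WorkflowPauseReason rows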
@@ -1352,7 +1352,7 @@ class RegisterService:

     @classmethod
     def invite_new_member(
-        cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
+        cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None
     ) -> str:
         if not inviter:
             raise ValueError("Inviter is required")

@@ -550,7 +550,7 @@ class AppDslService:
             "app": {
                 "name": app_model.name,
                 "mode": app_model.mode,
-                "icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
+                "icon": app_model.icon if app_model.icon_type == "image" else "🤖",
                 "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
                 "description": app_model.description,
                 "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon,

@@ -0,0 +1,45 @@
+"""Service for managing application task operations.
+
+This service provides centralized logic for task control operations
+like stopping tasks, handling both the legacy Redis flag mechanism and
+the new GraphEngine command channel mechanism.
+"""
+
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.graph_engine.manager import GraphEngineManager
+from models.model import AppMode
+
+
+class AppTaskService:
+    """Service for managing application task operations."""
+
+    @staticmethod
+    def stop_task(
+        task_id: str,
+        invoke_from: InvokeFrom,
+        user_id: str,
+        app_mode: AppMode,
+    ) -> None:
+        """Stop a running task.
+
+        This method handles stopping tasks using both mechanisms:
+        1. Legacy Redis flag mechanism (for backward compatibility)
+        2. New GraphEngine command channel (for workflow-based apps)
+
+        Args:
+            task_id: The task ID to stop
+            invoke_from: The source of the invoke (e.g., DEBUGGER, WEB_APP, SERVICE_API)
+            user_id: The user ID requesting the stop
+            app_mode: The application mode (CHAT, AGENT_CHAT, ADVANCED_CHAT, WORKFLOW, etc.)
+
+        Returns:
+            None
+        """
+        # Legacy mechanism: Set stop flag in Redis
+        AppQueueManager.set_stop_flag(task_id, invoke_from, user_id)
+
+        # New mechanism: Send stop command via GraphEngine for workflow-based apps
+        # This ensures proper workflow status recording in the persistence layer
+        if app_mode in (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW):
+            GraphEngineManager.send_stop_command(task_id)

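A hedged sketch of calling the new service (the import path `services.app_task_service` is assumed from the service naming convention; `task_id` and `current_user` are placeholders):

    from core.app.entities.app_invoke_entities import InvokeFrom
    from models.model import AppMode
    from services.app_task_service import AppTaskService

    AppTaskService.stop_task(
        task_id=task_id,                  # id of the running generation task
        invoke_from=InvokeFrom.DEBUGGER,  # where the stop request originated
        user_id=current_user.id,
        app_mode=AppMode.WORKFLOW,        # workflow modes also get the GraphEngine stop command
    )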
@@ -1375,6 +1375,11 @@ class DocumentService:

         document.name = name
         db.session.add(document)
+        if document.data_source_info_dict:
+            db.session.query(UploadFile).where(
+                UploadFile.id == document.data_source_info_dict["upload_file_id"]
+            ).update({UploadFile.name: name})
+
         db.session.commit()

         return document

@@ -0,0 +1,185 @@
+import csv
+import io
+import json
+from datetime import datetime
+
+from flask import Response
+from sqlalchemy import or_
+
+from extensions.ext_database import db
+from models.model import Account, App, Conversation, Message, MessageFeedback
+
+
+class FeedbackService:
+    @staticmethod
+    def export_feedbacks(
+        app_id: str,
+        from_source: str | None = None,
+        rating: str | None = None,
+        has_comment: bool | None = None,
+        start_date: str | None = None,
+        end_date: str | None = None,
+        format_type: str = "csv",
+    ):
+        """
+        Export feedback data with message details for analysis
+
+        Args:
+            app_id: Application ID
+            from_source: Filter by feedback source ('user' or 'admin')
+            rating: Filter by rating ('like' or 'dislike')
+            has_comment: Only include feedback with comments
+            start_date: Start date filter (YYYY-MM-DD)
+            end_date: End date filter (YYYY-MM-DD)
+            format_type: Export format ('csv' or 'json')
+        """
+
+        # Validate format early to avoid hitting DB when unnecessary
+        fmt = (format_type or "csv").lower()
+        if fmt not in {"csv", "json"}:
+            raise ValueError(f"Unsupported format: {format_type}")
+
+        # Build base query
+        query = (
+            db.session.query(MessageFeedback, Message, Conversation, App, Account)
+            .join(Message, MessageFeedback.message_id == Message.id)
+            .join(Conversation, MessageFeedback.conversation_id == Conversation.id)
+            .join(App, MessageFeedback.app_id == App.id)
+            .outerjoin(Account, MessageFeedback.from_account_id == Account.id)
+            .where(MessageFeedback.app_id == app_id)
+        )
+
+        # Apply filters
+        if from_source:
+            query = query.filter(MessageFeedback.from_source == from_source)
+
+        if rating:
+            query = query.filter(MessageFeedback.rating == rating)
+
+        if has_comment is not None:
+            if has_comment:
+                query = query.filter(MessageFeedback.content.isnot(None), MessageFeedback.content != "")
+            else:
+                query = query.filter(or_(MessageFeedback.content.is_(None), MessageFeedback.content == ""))
+
+        if start_date:
+            try:
+                start_dt = datetime.strptime(start_date, "%Y-%m-%d")
+                query = query.filter(MessageFeedback.created_at >= start_dt)
+            except ValueError:
+                raise ValueError(f"Invalid start_date format: {start_date}. Use YYYY-MM-DD")
+
+        if end_date:
+            try:
+                end_dt = datetime.strptime(end_date, "%Y-%m-%d")
+                query = query.filter(MessageFeedback.created_at <= end_dt)
+            except ValueError:
+                raise ValueError(f"Invalid end_date format: {end_date}. Use YYYY-MM-DD")
+
+        # Order by creation date (newest first)
+        query = query.order_by(MessageFeedback.created_at.desc())
+
+        # Execute query
+        results = query.all()
+
+        # Prepare data for export
+        export_data = []
+        for feedback, message, conversation, app, account in results:
+            # Get the user query from the message
+            user_query = message.query or (message.inputs.get("query", "") if message.inputs else "")
+
+            # Format the feedback data
+            feedback_record = {
+                "feedback_id": str(feedback.id),
+                "app_name": app.name,
+                "app_id": str(app.id),
+                "conversation_id": str(conversation.id),
+                "conversation_name": conversation.name or "",
+                "message_id": str(message.id),
+                "user_query": user_query,
+                "ai_response": message.answer[:500] + "..."
+                if len(message.answer) > 500
+                else message.answer,  # Truncate long responses
+                "feedback_rating": "👍" if feedback.rating == "like" else "👎",
+                "feedback_rating_raw": feedback.rating,
+                "feedback_comment": feedback.content or "",
+                "feedback_source": feedback.from_source,
+                "feedback_date": feedback.created_at.strftime("%Y-%m-%d %H:%M:%S"),
+                "message_date": message.created_at.strftime("%Y-%m-%d %H:%M:%S"),
+                "from_account_name": account.name if account else "",
+                "from_end_user_id": str(feedback.from_end_user_id) if feedback.from_end_user_id else "",
+                "has_comment": "Yes" if feedback.content and feedback.content.strip() else "No",
+            }
+            export_data.append(feedback_record)
+
+        # Export based on format
+        if fmt == "csv":
+            return FeedbackService._export_csv(export_data, app_id)
+        else:  # fmt == "json"
+            return FeedbackService._export_json(export_data, app_id)
+
+    @staticmethod
+    def _export_csv(data, app_id):
+        """Export data as CSV"""
+        if not data:
+            pass  # allow empty CSV with headers only
+
+        # Create CSV in memory
+        output = io.StringIO()
+
+        # Define headers
+        headers = [
+            "feedback_id",
+            "app_name",
+            "app_id",
+            "conversation_id",
+            "conversation_name",
+            "message_id",
+            "user_query",
+            "ai_response",
+            "feedback_rating",
+            "feedback_rating_raw",
+            "feedback_comment",
+            "feedback_source",
+            "feedback_date",
+            "message_date",
+            "from_account_name",
+            "from_end_user_id",
+            "has_comment",
+        ]
+
+        writer = csv.DictWriter(output, fieldnames=headers)
+        writer.writeheader()
+        writer.writerows(data)
+
+        # Create response without requiring app context
+        response = Response(output.getvalue(), mimetype="text/csv; charset=utf-8-sig")
+        response.headers["Content-Disposition"] = (
+            f"attachment; filename=dify_feedback_export_{app_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
+        )
+
+        return response
+
+    @staticmethod
+    def _export_json(data, app_id):
+        """Export data as JSON"""
+        response_data = {
+            "export_info": {
+                "app_id": app_id,
+                "export_date": datetime.now().isoformat(),
+                "total_records": len(data),
+                "data_source": "dify_feedback_export",
+            },
+            "feedback_data": data,
+        }
+
+        # Create response without requiring app context
+        response = Response(
+            json.dumps(response_data, ensure_ascii=False, indent=2),
+            mimetype="application/json; charset=utf-8",
+        )
+        response.headers["Content-Disposition"] = (
+            f"attachment; filename=dify_feedback_export_{app_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+        )
+
+        return response

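A hedged usage sketch for the new service (`app_id` is a placeholder; the method returns a Flask `Response` carrying a file-download header):

    response = FeedbackService.export_feedbacks(
        app_id=app_id,
        rating="dislike",
        has_comment=True,
        start_date="2024-01-01",
        end_date="2024-01-31",
        format_type="json",
    )
    print(response.headers["Content-Disposition"])  # attachment; filename=dify_feedback_export_...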
@@ -3,8 +3,8 @@ import os
 import uuid
 from typing import Literal, Union

-from sqlalchemy import Engine
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy import Engine, select
+from sqlalchemy.orm import Session, sessionmaker
 from werkzeug.exceptions import NotFound

 from configs import dify_config

@@ -29,7 +29,7 @@ PREVIEW_WORDS_LIMIT = 3000


 class FileService:
-    _session_maker: sessionmaker
+    _session_maker: sessionmaker[Session]

     def __init__(self, session_factory: sessionmaker | Engine | None = None):
         if isinstance(session_factory, Engine):

@@ -236,11 +236,10 @@ class FileService:
         return content.decode("utf-8")

     def delete_file(self, file_id: str):
-        with self._session_maker(expire_on_commit=False) as session:
-            upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
+        with self._session_maker() as session, session.begin():
+            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id))

-        if not upload_file:
-            return
-        storage.delete(upload_file.key)
-        session.delete(upload_file)
-        session.commit()
+            if not upload_file:
+                return
+            storage.delete(upload_file.key)
+            session.delete(upload_file)

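The rewritten `delete_file` leans on SQLAlchemy's `session.begin()` context, which commits on normal exit and rolls back on error, removing the need for an explicit `session.commit()`. A minimal sketch of the pattern (session factory, model, and `file_id` are placeholders):

    from sqlalchemy import select

    with session_maker() as session, session.begin():
        row = session.scalar(select(UploadFile).where(UploadFile.id == file_id))
        if row is not None:
            session.delete(row)  # flushed and committed automatically when the block exits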
@@ -5,6 +5,7 @@ import secrets
 from collections.abc import Mapping
 from typing import Any

+import orjson
 from flask import request
 from pydantic import BaseModel
 from sqlalchemy import select

@@ -169,7 +170,7 @@ class WebhookService:
         - method: HTTP method
         - headers: Request headers
         - query_params: Query parameters as strings
-        - body: Request body (varies by content type)
+        - body: Request body (varies by content type; JSON parsing errors raise ValueError)
         - files: Uploaded files (if any)
         """
         cls._validate_content_length()

@@ -255,14 +256,21 @@ class WebhookService:

         Returns:
             tuple: (body_data, files_data) where:
-                - body_data: Parsed JSON content or empty dict if parsing fails
+                - body_data: Parsed JSON content
                 - files_data: Empty dict (JSON requests don't contain files)
+
+        Raises:
+            ValueError: If JSON parsing fails
         """
+        raw_body = request.get_data(cache=True)
+        if not raw_body or raw_body.strip() == b"":
+            return {}, {}
+
         try:
-            body = request.get_json() or {}
-        except Exception:
-            logger.warning("Failed to parse JSON body")
-            body = {}
+            body = orjson.loads(raw_body)
+        except orjson.JSONDecodeError as exc:
+            logger.warning("Failed to parse JSON body: %s", exc)
+            raise ValueError(f"Invalid JSON body: {exc}") from exc
         return body, {}

     @classmethod

|
|||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.variables import Variable
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.entities import VariablePool, WorkflowNodeExecution
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
|
@ -24,6 +24,7 @@ from core.workflow.nodes import NodeType
|
|||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from enums.cloud_plan import CloudPlan
|
||||
|
|
|
|||
|
|
@@ -62,6 +62,7 @@ WEAVIATE_ENDPOINT=http://localhost:8080
 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
 WEAVIATE_GRPC_ENABLED=false
 WEAVIATE_BATCH_SIZE=100
+WEAVIATE_TOKENIZATION=word


 # Upload configuration

@@ -0,0 +1,106 @@
+"""Basic integration tests for Feedback API endpoints."""
+
+import uuid
+
+from flask.testing import FlaskClient
+
+
+class TestFeedbackApiBasic:
+    """Basic tests for feedback API endpoints."""
+
+    def test_feedback_export_endpoint_exists(self, test_client: FlaskClient, auth_header):
+        """Test that feedback export endpoint exists and handles basic requests."""
+
+        app_id = str(uuid.uuid4())
+
+        # Test endpoint exists (even if it fails, it should return 500 or 403, not 404)
+        response = test_client.get(
+            f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string={"format": "csv"}
+        )
+
+        # Should not return 404 (endpoint exists)
+        assert response.status_code != 404
+
+        # Should return authentication or permission error
+        assert response.status_code in [401, 403, 500]  # 500 if app doesn't exist, 403 if no permission
+
+    def test_feedback_summary_endpoint_exists(self, test_client: FlaskClient, auth_header):
+        """Test that feedback summary endpoint exists and handles basic requests."""
+
+        app_id = str(uuid.uuid4())
+
+        # Test endpoint exists
+        response = test_client.get(f"/console/api/apps/{app_id}/feedbacks/summary", headers=auth_header)
+
+        # Should not return 404 (endpoint exists)
+        assert response.status_code != 404
+
+        # Should return authentication or permission error
+        assert response.status_code in [401, 403, 500]
+
+    def test_feedback_export_invalid_format(self, test_client: FlaskClient, auth_header):
+        """Test feedback export endpoint with invalid format parameter."""
+
+        app_id = str(uuid.uuid4())
+
+        # Test with invalid format
+        response = test_client.get(
+            f"/console/api/apps/{app_id}/feedbacks/export",
+            headers=auth_header,
+            query_string={"format": "invalid_format"},
+        )
+
+        # Should not return 404
+        assert response.status_code != 404
+
+    def test_feedback_export_with_filters(self, test_client: FlaskClient, auth_header):
+        """Test feedback export endpoint with various filter parameters."""
+
+        app_id = str(uuid.uuid4())
+
+        # Test with various filter combinations
+        filter_params = [
+            {"from_source": "user"},
+            {"rating": "like"},
+            {"has_comment": True},
+            {"start_date": "2024-01-01"},
+            {"end_date": "2024-12-31"},
+            {"format": "json"},
+            {
+                "from_source": "admin",
+                "rating": "dislike",
+                "has_comment": True,
+                "start_date": "2024-01-01",
+                "end_date": "2024-12-31",
+                "format": "csv",
+            },
+        ]
+
+        for params in filter_params:
+            response = test_client.get(
+                f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string=params
+            )
+
+            # Should not return 404
+            assert response.status_code != 404
+
+    def test_feedback_export_invalid_dates(self, test_client: FlaskClient, auth_header):
+        """Test feedback export endpoint with invalid date formats."""
+
+        app_id = str(uuid.uuid4())
+
+        # Test with invalid date formats
+        invalid_dates = [
+            {"start_date": "invalid-date"},
+            {"end_date": "not-a-date"},
+            {"start_date": "2024-13-01"},  # Invalid month
+            {"end_date": "2024-12-32"},  # Invalid day
+        ]
+
+        for params in invalid_dates:
+            response = test_client.get(
+                f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string=params
+            )
+
+            # Should not return 404
+            assert response.status_code != 404

@@ -0,0 +1,334 @@
+"""Integration tests for Feedback Export API endpoints."""
+
+import json
+import uuid
+from datetime import datetime
+from types import SimpleNamespace
+from unittest import mock
+
+import pytest
+from flask.testing import FlaskClient
+
+from controllers.console.app import message as message_api
+from controllers.console.app import wraps
+from libs.datetime_utils import naive_utc_now
+from models import App, Tenant
+from models.account import Account, TenantAccountJoin, TenantAccountRole
+from models.model import AppMode, MessageFeedback
+from services.feedback_service import FeedbackService
+
+
+class TestFeedbackExportApi:
+    """Test feedback export API endpoints."""
+
+    @pytest.fixture
+    def mock_app_model(self):
+        """Create a mock App model for testing."""
+        app = App()
+        app.id = str(uuid.uuid4())
+        app.mode = AppMode.CHAT
+        app.tenant_id = str(uuid.uuid4())
+        app.status = "normal"
+        app.name = "Test App"
+        return app
+
+    @pytest.fixture
+    def mock_account(self, monkeypatch: pytest.MonkeyPatch):
+        """Create a mock Account for testing."""
+        account = Account(
+            name="Test User",
+            email="test@example.com",
+        )
+        account.last_active_at = naive_utc_now()
+        account.created_at = naive_utc_now()
+        account.updated_at = naive_utc_now()
+        account.id = str(uuid.uuid4())
+
+        # Create mock tenant
+        tenant = Tenant(name="Test Tenant")
+        tenant.id = str(uuid.uuid4())
+
+        mock_session_instance = mock.Mock()
+
+        mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
+        monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
+
+        mock_scalars_result = mock.Mock()
+        mock_scalars_result.one.return_value = tenant
+        monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
+
+        mock_session_context = mock.Mock()
+        mock_session_context.__enter__.return_value = mock_session_instance
+        monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
+
+        account.current_tenant = tenant
+        return account
+
+    @pytest.fixture
+    def sample_feedback_data(self):
+        """Create sample feedback data for testing."""
+        app_id = str(uuid.uuid4())
+        conversation_id = str(uuid.uuid4())
+        message_id = str(uuid.uuid4())
+
+        # Mock feedback data
+        user_feedback = MessageFeedback(
+            id=str(uuid.uuid4()),
+            app_id=app_id,
+            conversation_id=conversation_id,
+            message_id=message_id,
+            rating="like",
+            from_source="user",
+            content=None,
+            from_end_user_id=str(uuid.uuid4()),
+            from_account_id=None,
+            created_at=naive_utc_now(),
+        )
+
+        admin_feedback = MessageFeedback(
+            id=str(uuid.uuid4()),
+            app_id=app_id,
+            conversation_id=conversation_id,
+            message_id=message_id,
+            rating="dislike",
+            from_source="admin",
+            content="The response was not helpful",
+            from_end_user_id=None,
+            from_account_id=str(uuid.uuid4()),
+            created_at=naive_utc_now(),
+        )
+
+        # Mock message and conversation
+        mock_message = SimpleNamespace(
+            id=message_id,
+            conversation_id=conversation_id,
+            query="What is the weather today?",
+            answer="It's sunny and 25 degrees outside.",
+            inputs={"query": "What is the weather today?"},
+            created_at=naive_utc_now(),
+        )
+
+        mock_conversation = SimpleNamespace(id=conversation_id, name="Weather Conversation", app_id=app_id)
+
+        mock_app = SimpleNamespace(id=app_id, name="Weather App")
+
+        return {
+            "user_feedback": user_feedback,
+            "admin_feedback": admin_feedback,
+            "message": mock_message,
+            "conversation": mock_conversation,
+            "app": mock_app,
+        }
+
+    @pytest.mark.parametrize(
+        ("role", "status"),
+        [
+            (TenantAccountRole.OWNER, 200),
+            (TenantAccountRole.ADMIN, 200),
+            (TenantAccountRole.EDITOR, 200),
+            (TenantAccountRole.NORMAL, 403),
+            (TenantAccountRole.DATASET_OPERATOR, 403),
+        ],
+    )
+    def test_feedback_export_permissions(
+        self,
+        test_client: FlaskClient,
+        auth_header,
+        monkeypatch,
+        mock_app_model,
+        mock_account,
+        role: TenantAccountRole,
+        status: int,
+    ):
+        """Test feedback export endpoint permissions."""
+
+        # Setup mocks
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        mock_export_feedbacks = mock.Mock(return_value="mock csv response")
+        monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
+
+        monkeypatch.setattr(message_api, "current_user", mock_account)
+
+        # Set user role
+        mock_account.role = role
+
+        response = test_client.get(
+            f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
+            headers=auth_header,
+            query_string={"format": "csv"},
+        )
+
+        assert response.status_code == status
+
+        if status == 200:
+            mock_export_feedbacks.assert_called_once()
+
+    def test_feedback_export_csv_format(
+        self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
+    ):
+        """Test feedback export in CSV format."""
+
+        # Setup mocks
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        # Create mock CSV response
+        mock_csv_content = (
+            "feedback_id,app_name,conversation_id,user_query,ai_response,feedback_rating,feedback_comment\n"
+        )
+        mock_csv_content += f"{sample_feedback_data['user_feedback'].id},{sample_feedback_data['app'].name},"
+        mock_csv_content += f"{sample_feedback_data['conversation'].id},{sample_feedback_data['message'].query},"
+        mock_csv_content += f"{sample_feedback_data['message'].answer},👍,\n"
+
+        mock_response = mock.Mock()
+        mock_response.headers = {"Content-Type": "text/csv; charset=utf-8-sig"}
+        mock_response.data = mock_csv_content.encode("utf-8")
+
+        mock_export_feedbacks = mock.Mock(return_value=mock_response)
+        monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
+
+        monkeypatch.setattr(message_api, "current_user", mock_account)
+
+        response = test_client.get(
+            f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
+            headers=auth_header,
+            query_string={"format": "csv", "from_source": "user"},
+        )
+
+        assert response.status_code == 200
+        assert "text/csv" in response.content_type
+
+    def test_feedback_export_json_format(
+        self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
+    ):
+        """Test feedback export in JSON format."""
+
+        # Setup mocks
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        mock_json_response = {
+            "export_info": {
+                "app_id": mock_app_model.id,
+                "export_date": datetime.now().isoformat(),
+                "total_records": 2,
+                "data_source": "dify_feedback_export",
+            },
+            "feedback_data": [
+                {
+                    "feedback_id": sample_feedback_data["user_feedback"].id,
+                    "feedback_rating": "👍",
+                    "feedback_rating_raw": "like",
+                    "feedback_comment": "",
+                }
+            ],
+        }
+
+        mock_response = mock.Mock()
+        mock_response.headers = {"Content-Type": "application/json; charset=utf-8"}
+        mock_response.data = json.dumps(mock_json_response).encode("utf-8")
+
+        mock_export_feedbacks = mock.Mock(return_value=mock_response)
+        monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
+
+        monkeypatch.setattr(message_api, "current_user", mock_account)
+
+        response = test_client.get(
+            f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
+            headers=auth_header,
+            query_string={"format": "json"},
+        )
+
+        assert response.status_code == 200
+        assert "application/json" in response.content_type
+
+    def test_feedback_export_with_filters(
+        self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
+    ):
+        """Test feedback export with various filters."""
+
+        # Setup mocks
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        mock_export_feedbacks = mock.Mock(return_value="mock filtered response")
+        monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
+
+        monkeypatch.setattr(message_api, "current_user", mock_account)
+
+        # Test with multiple filters
+        response = test_client.get(
+            f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
+            headers=auth_header,
+            query_string={
+                "from_source": "user",
+                "rating": "dislike",
+                "has_comment": True,
+                "start_date": "2024-01-01",
+                "end_date": "2024-12-31",
+                "format": "csv",
+            },
+        )
+
+        assert response.status_code == 200
+
+        # Verify service was called with correct parameters
+        mock_export_feedbacks.assert_called_once_with(
+            app_id=mock_app_model.id,
+            from_source="user",
+            rating="dislike",
+            has_comment=True,
+            start_date="2024-01-01",
+            end_date="2024-12-31",
+            format_type="csv",
+        )
+
+    def test_feedback_export_invalid_date_format(
+        self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
+    ):
+        """Test feedback export with invalid date format."""
+
+        # Setup mocks
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        # Mock the service to raise ValueError for invalid date
+        mock_export_feedbacks = mock.Mock(side_effect=ValueError("Invalid date format"))
+        monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
+
+        monkeypatch.setattr(message_api, "current_user", mock_account)
+
+        response = test_client.get(
+            f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
+            headers=auth_header,
+            query_string={"start_date": "invalid-date", "format": "csv"},
+        )
+
+        assert response.status_code == 400
+        response_json = response.get_json()
+        assert "Parameter validation error" in response_json["error"]
+
+    def test_feedback_export_server_error(
+        self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
+    ):
+        """Test feedback export with server error."""
+
+        # Setup mocks
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        # Mock the service to raise an exception
+        mock_export_feedbacks = mock.Mock(side_effect=Exception("Database connection failed"))
+        monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
+
+        monkeypatch.setattr(message_api, "current_user", mock_account)
+
+        response = test_client.get(
+            f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
+            headers=auth_header,
+            query_string={"format": "csv"},
+        )
+
+        assert response.status_code == 500

@ -319,7 +319,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
|||
|
||||
# Create pause event
|
||||
event = GraphRunPausedEvent(
|
||||
reason=SchedulingPause(message="test pause"),
|
||||
reasons=[SchedulingPause(message="test pause")],
|
||||
outputs={"intermediate": "result"},
|
||||
)
|
||||
|
||||
|
|
@ -381,7 +381,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
|||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act - Save pause state
|
||||
layer.on_event(event)
|
||||
|
|
@ -390,6 +390,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
|||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
|
||||
assert pause_entity.get_pause_reasons() == event.reasons
|
||||
|
||||
state_bytes = pause_entity.get_state()
|
||||
resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
|
||||
|
|
@ -414,7 +415,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
|||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
|
@ -448,7 +449,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
|||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
|
@ -514,7 +515,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
|||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
|
@ -570,7 +571,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
|||
layer = self._create_pause_state_persistence_layer()
|
||||
# Don't initialize - graph_runtime_state should not be set
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||
|
||||
# Act & Assert - Should raise AttributeError
|
||||
with pytest.raises(AttributeError):
|
||||
|
|
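Every hunk above applies the same mechanical migration: GraphRunPausedEvent drops its single reason field in favor of a reasons list. A minimal sketch of the shape change, assuming Pydantic models (the real event classes live in core.workflow and carry more fields):

    # Sketch only: models the scalar-to-list field migration, not the real classes.
    from pydantic import BaseModel, Field

    class SchedulingPause(BaseModel):
        message: str

    class GraphRunPausedEvent(BaseModel):
        # Previously: reason: SchedulingPause
        # Now a list, so one event can carry several pause causes
        reasons: list[SchedulingPause] = Field(default_factory=list)
        outputs: dict = Field(default_factory=dict)

    event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
    assert event.reasons[0].message == "test pause"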
@@ -295,9 +295,13 @@ class TestAPIBasedExtensionService:
        original_name = created_extension.name
        original_endpoint = created_extension.api_endpoint

-       # Update the extension
+       # Update the extension with guaranteed different values
        new_name = fake.company()
+       # Ensure new endpoint is different from original
        new_endpoint = f"https://{fake.domain_name()}/api"
+       # If by chance they're the same, generate a new one
+       while new_endpoint == original_endpoint:
+           new_endpoint = f"https://{fake.domain_name()}/api"
        new_api_key = fake.password(length=20)

        created_extension.name = new_name
@@ -0,0 +1,386 @@
"""Unit tests for FeedbackService."""

import json
from datetime import datetime
from types import SimpleNamespace
from unittest import mock

import pytest

from extensions.ext_database import db
from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService


class TestFeedbackService:
    """Test FeedbackService methods."""

    @pytest.fixture
    def mock_db_session(self, monkeypatch):
        """Mock database session."""
        mock_session = mock.Mock()
        monkeypatch.setattr(db, "session", mock_session)
        return mock_session

    @pytest.fixture
    def sample_data(self):
        """Create sample data for testing."""
        app_id = "test-app-id"

        # Create mock models
        app = App(id=app_id, name="Test App")

        conversation = Conversation(id="test-conversation-id", app_id=app_id, name="Test Conversation")

        message = Message(
            id="test-message-id",
            conversation_id="test-conversation-id",
            query="What is AI?",
            answer="AI is artificial intelligence.",
            inputs={"query": "What is AI?"},
            created_at=datetime(2024, 1, 1, 10, 0, 0),
        )

        # Use SimpleNamespace to avoid ORM model constructor issues
        user_feedback = SimpleNamespace(
            id="user-feedback-id",
            app_id=app_id,
            conversation_id="test-conversation-id",
            message_id="test-message-id",
            rating="like",
            from_source="user",
            content="Great answer!",
            from_end_user_id="user-123",
            from_account_id=None,
            from_account=None,  # Mock account object
            created_at=datetime(2024, 1, 1, 10, 5, 0),
        )

        admin_feedback = SimpleNamespace(
            id="admin-feedback-id",
            app_id=app_id,
            conversation_id="test-conversation-id",
            message_id="test-message-id",
            rating="dislike",
            from_source="admin",
            content="Could be more detailed",
            from_end_user_id=None,
            from_account_id="admin-456",
            from_account=SimpleNamespace(name="Admin User"),  # Mock account object
            created_at=datetime(2024, 1, 1, 10, 10, 0),
        )

        return {
            "app": app,
            "conversation": conversation,
            "message": message,
            "user_feedback": user_feedback,
            "admin_feedback": admin_feedback,
        }

    def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
        """Test exporting feedback data in CSV format."""

        # Setup mock query result
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = [
            (
                sample_data["user_feedback"],
                sample_data["message"],
                sample_data["conversation"],
                sample_data["app"],
                sample_data["user_feedback"].from_account,
            )
        ]

        mock_db_session.query.return_value = mock_query

        # Test CSV export
        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")

        # Verify response structure
        assert hasattr(result, "headers")
        assert "text/csv" in result.headers["Content-Type"]
        assert "attachment" in result.headers["Content-Disposition"]

        # Check CSV content
        csv_content = result.get_data(as_text=True)
        # Verify essential headers exist (order may include additional columns)
        assert "feedback_id" in csv_content
        assert "app_name" in csv_content
        assert "conversation_id" in csv_content
        assert sample_data["app"].name in csv_content
        assert sample_data["message"].query in csv_content

    def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
        """Test exporting feedback data in JSON format."""

        # Setup mock query result
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = [
            (
                sample_data["admin_feedback"],
                sample_data["message"],
                sample_data["conversation"],
                sample_data["app"],
                sample_data["admin_feedback"].from_account,
            )
        ]

        mock_db_session.query.return_value = mock_query

        # Test JSON export
        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

        # Verify response structure
        assert hasattr(result, "headers")
        assert "application/json" in result.headers["Content-Type"]
        assert "attachment" in result.headers["Content-Disposition"]

        # Check JSON content
        json_content = json.loads(result.get_data(as_text=True))
        assert "export_info" in json_content
        assert "feedback_data" in json_content
        assert json_content["export_info"]["app_id"] == sample_data["app"].id
        assert json_content["export_info"]["total_records"] == 1

    def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
        """Test exporting feedback with various filters."""

        # Setup mock query result
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = [
            (
                sample_data["admin_feedback"],
                sample_data["message"],
                sample_data["conversation"],
                sample_data["app"],
                sample_data["admin_feedback"].from_account,
            )
        ]

        mock_db_session.query.return_value = mock_query

        # Test with filters
        result = FeedbackService.export_feedbacks(
            app_id=sample_data["app"].id,
            from_source="admin",
            rating="dislike",
            has_comment=True,
            start_date="2024-01-01",
            end_date="2024-12-31",
            format_type="csv",
        )

        # Verify filters were applied
        assert mock_query.filter.called
        filter_calls = mock_query.filter.call_args_list
        # At least three filter invocations are expected (source, rating, comment)
        assert len(filter_calls) >= 3

    def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
        """Test exporting feedback when no data exists."""

        # Setup mock query result with no data
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = []

        mock_db_session.query.return_value = mock_query

        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")

        # Should return an empty CSV with headers only
        assert hasattr(result, "headers")
        assert "text/csv" in result.headers["Content-Type"]
        csv_content = result.get_data(as_text=True)
        # Headers should exist (order can include additional columns)
        assert "feedback_id" in csv_content
        assert "app_name" in csv_content
        assert "conversation_id" in csv_content
        # No data rows expected
        assert len([line for line in csv_content.strip().splitlines() if line.strip()]) == 1

    def test_export_feedbacks_invalid_date_format(self, mock_db_session, sample_data):
        """Test exporting feedback with invalid date format."""

        # Test with invalid start_date
        with pytest.raises(ValueError, match="Invalid start_date format"):
            FeedbackService.export_feedbacks(app_id=sample_data["app"].id, start_date="invalid-date-format")

        # Test with invalid end_date
        with pytest.raises(ValueError, match="Invalid end_date format"):
            FeedbackService.export_feedbacks(app_id=sample_data["app"].id, end_date="invalid-date-format")

    def test_export_feedbacks_invalid_format(self, mock_db_session, sample_data):
        """Test exporting feedback with unsupported format."""

        with pytest.raises(ValueError, match="Unsupported format"):
            FeedbackService.export_feedbacks(
                app_id=sample_data["app"].id,
                format_type="xml",  # Unsupported format
            )

    def test_export_feedbacks_long_response_truncation(self, mock_db_session, sample_data):
        """Test that long AI responses are truncated in export."""

        # Create message with long response
        long_message = Message(
            id="long-message-id",
            conversation_id="test-conversation-id",
            query="What is AI?",
            answer="A" * 600,  # 600 character response
            inputs={"query": "What is AI?"},
            created_at=datetime(2024, 1, 1, 10, 0, 0),
        )

        # Setup mock query result
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = [
            (
                sample_data["user_feedback"],
                long_message,
                sample_data["conversation"],
                sample_data["app"],
                sample_data["user_feedback"].from_account,
            )
        ]

        mock_db_session.query.return_value = mock_query

        # Test export
        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

        # Check JSON content
        json_content = json.loads(result.get_data(as_text=True))
        exported_answer = json_content["feedback_data"][0]["ai_response"]

        # Should be truncated with ellipsis
        assert len(exported_answer) <= 503  # 500 + "..."
        assert exported_answer.endswith("...")
        assert len(exported_answer) > 500  # Should be close to limit

    def test_export_feedbacks_unicode_content(self, mock_db_session, sample_data):
        """Test exporting feedback with unicode content (Chinese characters)."""

        # Create feedback with Chinese content (use SimpleNamespace to avoid ORM constructor constraints)
        chinese_feedback = SimpleNamespace(
            id="chinese-feedback-id",
            app_id=sample_data["app"].id,
            conversation_id="test-conversation-id",
            message_id="test-message-id",
            rating="dislike",
            from_source="user",
            content="回答不够详细,需要更多信息",
            from_end_user_id="user-123",
            from_account_id=None,
            created_at=datetime(2024, 1, 1, 10, 5, 0),
        )

        # Create Chinese message
        chinese_message = Message(
            id="chinese-message-id",
            conversation_id="test-conversation-id",
            query="什么是人工智能?",
            answer="人工智能是模拟人类智能的技术。",
            inputs={"query": "什么是人工智能?"},
            created_at=datetime(2024, 1, 1, 10, 0, 0),
        )

        # Setup mock query result
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = [
            (
                chinese_feedback,
                chinese_message,
                sample_data["conversation"],
                sample_data["app"],
                None,  # No account for user feedback
            )
        ]

        mock_db_session.query.return_value = mock_query

        # Test export
        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")

        # Check that unicode content is preserved
        csv_content = result.get_data(as_text=True)
        assert "什么是人工智能?" in csv_content
        assert "回答不够详细,需要更多信息" in csv_content
        assert "人工智能是模拟人类智能的技术" in csv_content

    def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
        """Test that rating emojis are properly formatted in export."""

        # Setup mock query result with both like and dislike feedback
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = [
            (
                sample_data["user_feedback"],
                sample_data["message"],
                sample_data["conversation"],
                sample_data["app"],
                sample_data["user_feedback"].from_account,
            ),
            (
                sample_data["admin_feedback"],
                sample_data["message"],
                sample_data["conversation"],
                sample_data["app"],
                sample_data["admin_feedback"].from_account,
            ),
        ]

        mock_db_session.query.return_value = mock_query

        # Test export
        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

        # Check JSON content for emoji ratings
        json_content = json.loads(result.get_data(as_text=True))
        feedback_data = json_content["feedback_data"]

        # Should have both feedback records
        assert len(feedback_data) == 2

        # Check that emojis are properly set
        like_feedback = next(f for f in feedback_data if f["feedback_rating_raw"] == "like")
        dislike_feedback = next(f for f in feedback_data if f["feedback_rating_raw"] == "dislike")

        assert like_feedback["feedback_rating"] == "👍"
        assert dislike_feedback["feedback_rating"] == "👎"
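All of the tests in this file rebuild the same self-returning mock so that SQLAlchemy-style chaining (query(...).join(...).where(...).order_by(...).all()) resolves to canned rows. The pattern could be factored into a shared helper; a sketch (the helper name is hypothetical):

    # Sketch only: a reusable version of the chained-query mock built repeatedly above.
    from unittest import mock

    def make_chained_query_mock(rows):
        query = mock.Mock()
        # Each chaining method returns the mock itself, so any call order resolves
        for name in ("join", "outerjoin", "where", "filter", "order_by"):
            getattr(query, name).return_value = query
        query.all.return_value = rows
        return query

    query = make_chained_query_mock(rows=[])
    assert query.join().where().order_by().all() == []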
@@ -334,12 +334,14 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        # Assert - Pause state created
        assert pause_entity is not None
        assert pause_entity.id is not None
        assert pause_entity.workflow_execution_id == workflow_run.id
+       assert list(pause_entity.get_pause_reasons()) == []
        # Convert both to strings for comparison
        retrieved_state = pause_entity.get_state()
        if isinstance(retrieved_state, bytes):

@@ -366,6 +368,7 @@ class TestWorkflowPauseIntegration:
        if isinstance(retrieved_state, bytes):
            retrieved_state = retrieved_state.decode()
        assert retrieved_state == test_state
+       assert list(retrieved_entity.get_pause_reasons()) == []

        # Act - Resume workflow
        resumed_entity = repository.resume_workflow_pause(

@@ -402,6 +405,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        assert pause_entity is not None

@@ -432,6 +436,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

    @pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)

@@ -449,6 +454,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        self.session.refresh(workflow_run)

@@ -480,6 +486,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        self.session.refresh(workflow_run)

@@ -503,6 +510,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )
        pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
        pause_model.resumed_at = naive_utc_now()

@@ -530,6 +538,7 @@ class TestWorkflowPauseIntegration:
                workflow_run_id=nonexistent_id,
                state_owner_user_id=self.test_user_id,
                state=test_state,
+               pause_reasons=[],
            )

    def test_resume_nonexistent_workflow_run(self):

@@ -543,6 +552,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        nonexistent_id = str(uuid.uuid4())

@@ -570,6 +580,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        # Manually adjust timestamps for testing

@@ -648,6 +659,7 @@ class TestWorkflowPauseIntegration:
                workflow_run_id=workflow_run.id,
                state_owner_user_id=self.test_user_id,
                state=test_state,
+               pause_reasons=[],
            )
            pause_entities.append(pause_entity)

@@ -750,6 +762,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run1.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        # Try to access pause from tenant 2 using tenant 1's repository

@@ -762,6 +775,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run2.id,
            state_owner_user_id=account2.id,
            state=test_state,
+           pause_reasons=[],
        )

        # Assert - Both pauses should exist and be separate

@@ -782,6 +796,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        # Verify pause is properly scoped

@@ -802,6 +817,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=test_state,
+           pause_reasons=[],
        )

        # Assert - Verify file was uploaded to storage

@@ -828,9 +844,7 @@ class TestWorkflowPauseIntegration:
        repository = self._get_workflow_run_repository()

        pause_entity = repository.create_workflow_pause(
-           workflow_run_id=workflow_run.id,
-           state_owner_user_id=self.test_user_id,
-           state=test_state,
+           workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=test_state, pause_reasons=[]
        )

        # Get file info before deletion

@@ -868,6 +882,7 @@ class TestWorkflowPauseIntegration:
            workflow_run_id=workflow_run.id,
            state_owner_user_id=self.test_user_id,
            state=large_state_json,
+           pause_reasons=[],
        )

        # Assert

@@ -902,9 +917,7 @@ class TestWorkflowPauseIntegration:

        # Pause
        pause_entity = repository.create_workflow_pause(
-           workflow_run_id=workflow_run.id,
-           state_owner_user_id=self.test_user_id,
-           state=state,
+           workflow_run_id=workflow_run.id, state_owner_user_id=self.test_user_id, state=state, pause_reasons=[]
        )
        assert pause_entity is not None
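The recurring +pause_reasons=[] lines track a widened create_workflow_pause signature. An in-memory stand-in showing the assumed new parameter (the real repository persists state to the database and object storage):

    # Sketch only: FakeWorkflowRunRepository is illustrative, not the real repository.
    import dataclasses

    @dataclasses.dataclass
    class FakePause:
        workflow_run_id: str
        state_owner_user_id: str
        state: str
        pause_reasons: list

    class FakeWorkflowRunRepository:
        def __init__(self):
            self.pauses = []

        def create_workflow_pause(self, workflow_run_id, state_owner_user_id, state, pause_reasons):
            # pause_reasons is now part of the persisted pause record
            pause = FakePause(workflow_run_id, state_owner_user_id, state, pause_reasons)
            self.pauses.append(pause)
            return pause

    repo = FakeWorkflowRunRepository()
    pause = repo.create_workflow_pause("run-1", "user-1", "{}", pause_reasons=[])
    assert pause.pause_reasons == []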
@@ -31,7 +31,7 @@ class TestDataFactory:

    @staticmethod
    def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
-       return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
+       return GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")], outputs=outputs or {})

    @staticmethod
    def create_graph_run_started_event() -> GraphRunStartedEvent:

@@ -255,15 +255,17 @@ class TestPauseStatePersistenceLayer:
        layer.on_event(event)

        mock_factory.assert_called_once_with(session_factory)
-       mock_repo.create_workflow_pause.assert_called_once_with(
-           workflow_run_id="run-123",
-           state_owner_user_id="owner-123",
-           state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
-       )
-       serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
+       assert mock_repo.create_workflow_pause.call_count == 1
+       call_kwargs = mock_repo.create_workflow_pause.call_args.kwargs
+       assert call_kwargs["workflow_run_id"] == "run-123"
+       assert call_kwargs["state_owner_user_id"] == "owner-123"
+       serialized_state = call_kwargs["state"]
        resumption_context = WorkflowResumptionContext.loads(serialized_state)
        assert resumption_context.serialized_graph_runtime_state == expected_state
        assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
+       pause_reasons = call_kwargs["pause_reasons"]

+       assert isinstance(pause_reasons, list)

    def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
        session_factory = Mock(name="session_factory")
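The rewritten assertions above trade a rigid assert_called_once_with for direct inspection of call_args.kwargs, so some arguments are checked exactly while others only need to satisfy a predicate. The unittest.mock pattern in isolation:

    # Sketch only: demonstrates inspecting keyword arguments captured by a Mock.
    from unittest.mock import Mock

    repo = Mock()
    repo.create_workflow_pause(workflow_run_id="run-123", state="{}", pause_reasons=[])

    assert repo.create_workflow_pause.call_count == 1
    kwargs = repo.create_workflow_pause.call_args.kwargs
    assert kwargs["workflow_run_id"] == "run-123"
    # Loosely-checked arguments only need the right shape, not a fixed value
    assert isinstance(kwargs["pause_reasons"], list)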
@@ -19,38 +19,18 @@ class TestPrivateWorkflowPauseEntity:
        mock_pause_model.resumed_at = None

        # Create entity
-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        # Verify initialization
        assert entity._pause_model is mock_pause_model
        assert entity._cached_state is None

-   def test_from_models_classmethod(self):
-       """Test from_models class method."""
-       # Create mock models
-       mock_pause_model = MagicMock(spec=WorkflowPauseModel)
-       mock_pause_model.id = "pause-123"
-       mock_pause_model.workflow_run_id = "execution-456"
-
-       # Create entity using from_models
-       entity = _PrivateWorkflowPauseEntity.from_models(
-           workflow_pause_model=mock_pause_model,
-       )
-
-       # Verify entity creation
-       assert isinstance(entity, _PrivateWorkflowPauseEntity)
-       assert entity._pause_model is mock_pause_model
-
    def test_id_property(self):
        """Test id property returns pause model ID."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"

-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        assert entity.id == "pause-123"

@@ -59,9 +39,7 @@ class TestPrivateWorkflowPauseEntity:
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.workflow_run_id = "execution-456"

-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        assert entity.workflow_execution_id == "execution-456"

@@ -72,9 +50,7 @@ class TestPrivateWorkflowPauseEntity:
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.resumed_at = resumed_at

-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        assert entity.resumed_at == resumed_at

@@ -83,9 +59,7 @@ class TestPrivateWorkflowPauseEntity:
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.resumed_at = None

-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        assert entity.resumed_at is None

@@ -98,9 +72,7 @@ class TestPrivateWorkflowPauseEntity:
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.state_object_key = "test-state-key"

-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        # First call should load from storage
        result = entity.get_state()

@@ -118,9 +90,7 @@ class TestPrivateWorkflowPauseEntity:
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.state_object_key = "test-state-key"

-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        # First call
        result1 = entity.get_state()

@@ -139,9 +109,7 @@ class TestPrivateWorkflowPauseEntity:

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)

-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        # Pre-cache data
        entity._cached_state = state_data

@@ -162,9 +130,7 @@ class TestPrivateWorkflowPauseEntity:

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)

-       entity = _PrivateWorkflowPauseEntity(
-           pause_model=mock_pause_model,
-       )
+       entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])

        result = entity.get_state()
@@ -8,12 +8,13 @@ from typing import Any
import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
+from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
+from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

@@ -178,8 +178,7 @@ def test_pause_command():
    assert any(isinstance(e, GraphRunStartedEvent) for e in events)
    pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
    assert len(pause_events) == 1
-   assert pause_events[0].reason == SchedulingPause(message="User requested pause")
+   assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")]

    graph_execution = engine.graph_runtime_state.graph_execution
    assert graph_execution.paused
-   assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
+   assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
@@ -0,0 +1,488 @@
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.exc import (
    CodeNodeError,
    DepthLimitError,
    OutputValidationError,
)


class TestCodeNodeExceptions:
    """Test suite for code node exceptions."""

    def test_code_node_error_is_value_error(self):
        """Test CodeNodeError inherits from ValueError."""
        error = CodeNodeError("test error")

        assert isinstance(error, ValueError)
        assert str(error) == "test error"

    def test_output_validation_error_is_code_node_error(self):
        """Test OutputValidationError inherits from CodeNodeError."""
        error = OutputValidationError("validation failed")

        assert isinstance(error, CodeNodeError)
        assert isinstance(error, ValueError)
        assert str(error) == "validation failed"

    def test_depth_limit_error_is_code_node_error(self):
        """Test DepthLimitError inherits from CodeNodeError."""
        error = DepthLimitError("depth exceeded")

        assert isinstance(error, CodeNodeError)
        assert isinstance(error, ValueError)
        assert str(error) == "depth exceeded"

    def test_code_node_error_with_empty_message(self):
        """Test CodeNodeError with empty message."""
        error = CodeNodeError("")

        assert str(error) == ""

    def test_output_validation_error_with_field_info(self):
        """Test OutputValidationError with field information."""
        error = OutputValidationError("Output 'result' is not a valid type")

        assert "result" in str(error)
        assert "not a valid type" in str(error)

    def test_depth_limit_error_with_limit_info(self):
        """Test DepthLimitError with limit information."""
        error = DepthLimitError("Depth limit 5 reached, object too deep")

        assert "5" in str(error)
        assert "too deep" in str(error)


class TestCodeNodeClassMethods:
    """Test suite for CodeNode class methods."""

    def test_code_node_version(self):
        """Test CodeNode version method."""
        version = CodeNode.version()

        assert version == "1"

    def test_get_default_config_python3(self):
        """Test get_default_config for Python3."""
        config = CodeNode.get_default_config(filters={"code_language": CodeLanguage.PYTHON3})

        assert config is not None
        assert isinstance(config, dict)

    def test_get_default_config_javascript(self):
        """Test get_default_config for JavaScript."""
        config = CodeNode.get_default_config(filters={"code_language": CodeLanguage.JAVASCRIPT})

        assert config is not None
        assert isinstance(config, dict)

    def test_get_default_config_no_filters(self):
        """Test get_default_config with no filters defaults to Python3."""
        config = CodeNode.get_default_config()

        assert config is not None
        assert isinstance(config, dict)

    def test_get_default_config_empty_filters(self):
        """Test get_default_config with empty filters."""
        config = CodeNode.get_default_config(filters={})

        assert config is not None


class TestCodeNodeCheckMethods:
    """Test suite for CodeNode check methods."""

    def test_check_string_none_value(self):
        """Test _check_string with None value."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_string(None, "test_var")

        assert result is None

    def test_check_string_removes_null_bytes(self):
        """Test _check_string removes null bytes."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_string("hello\x00world", "test_var")

        assert result == "helloworld"
        assert "\x00" not in result

    def test_check_string_valid_string(self):
        """Test _check_string with valid string."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_string("valid string", "test_var")

        assert result == "valid string"

    def test_check_string_empty_string(self):
        """Test _check_string with empty string."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_string("", "test_var")

        assert result == ""

    def test_check_string_with_unicode(self):
        """Test _check_string with unicode characters."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_string("你好世界🌍", "test_var")

        assert result == "你好世界🌍"

    def test_check_boolean_none_value(self):
        """Test _check_boolean with None value."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_boolean(None, "test_var")

        assert result is None

    def test_check_boolean_true_value(self):
        """Test _check_boolean with True value."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_boolean(True, "test_var")

        assert result is True

    def test_check_boolean_false_value(self):
        """Test _check_boolean with False value."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_boolean(False, "test_var")

        assert result is False

    def test_check_number_none_value(self):
        """Test _check_number with None value."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_number(None, "test_var")

        assert result is None

    def test_check_number_integer_value(self):
        """Test _check_number with integer value."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_number(42, "test_var")

        assert result == 42

    def test_check_number_float_value(self):
        """Test _check_number with float value."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_number(3.14, "test_var")

        assert result == 3.14

    def test_check_number_zero(self):
        """Test _check_number with zero."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_number(0, "test_var")

        assert result == 0

    def test_check_number_negative(self):
        """Test _check_number with negative number."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_number(-100, "test_var")

        assert result == -100

    def test_check_number_negative_float(self):
        """Test _check_number with negative float."""
        node = CodeNode.__new__(CodeNode)
        result = node._check_number(-3.14159, "test_var")

        assert result == -3.14159


class TestCodeNodeConvertBooleanToInt:
    """Test suite for _convert_boolean_to_int static method."""

    def test_convert_none_returns_none(self):
        """Test converting None returns None."""
        result = CodeNode._convert_boolean_to_int(None)

        assert result is None

    def test_convert_true_returns_one(self):
        """Test converting True returns 1."""
        result = CodeNode._convert_boolean_to_int(True)

        assert result == 1
        assert isinstance(result, int)

    def test_convert_false_returns_zero(self):
        """Test converting False returns 0."""
        result = CodeNode._convert_boolean_to_int(False)

        assert result == 0
        assert isinstance(result, int)

    def test_convert_integer_returns_same(self):
        """Test converting integer returns same value."""
        result = CodeNode._convert_boolean_to_int(42)

        assert result == 42

    def test_convert_float_returns_same(self):
        """Test converting float returns same value."""
        result = CodeNode._convert_boolean_to_int(3.14)

        assert result == 3.14

    def test_convert_zero_returns_zero(self):
        """Test converting zero returns zero."""
        result = CodeNode._convert_boolean_to_int(0)

        assert result == 0

    def test_convert_negative_returns_same(self):
        """Test converting negative number returns same value."""
        result = CodeNode._convert_boolean_to_int(-100)

        assert result == -100


class TestCodeNodeExtractVariableSelector:
    """Test suite for _extract_variable_selector_to_variable_mapping."""

    def test_extract_empty_variables(self):
        """Test extraction with no variables."""
        node_data = {
            "title": "Test",
            "variables": [],
            "code_language": "python3",
            "code": "def main(): return {}",
            "outputs": {},
        }

        result = CodeNode._extract_variable_selector_to_variable_mapping(
            graph_config={},
            node_id="node_1",
            node_data=node_data,
        )

        assert result == {}

    def test_extract_single_variable(self):
        """Test extraction with single variable."""
        node_data = {
            "title": "Test",
            "variables": [
                {"variable": "input_text", "value_selector": ["start", "text"]},
            ],
            "code_language": "python3",
            "code": "def main(): return {}",
            "outputs": {},
        }

        result = CodeNode._extract_variable_selector_to_variable_mapping(
            graph_config={},
            node_id="node_1",
            node_data=node_data,
        )

        assert "node_1.input_text" in result
        assert result["node_1.input_text"] == ["start", "text"]

    def test_extract_multiple_variables(self):
        """Test extraction with multiple variables."""
        node_data = {
            "title": "Test",
            "variables": [
                {"variable": "var1", "value_selector": ["node_a", "output1"]},
                {"variable": "var2", "value_selector": ["node_b", "output2"]},
                {"variable": "var3", "value_selector": ["node_c", "output3"]},
            ],
            "code_language": "python3",
            "code": "def main(): return {}",
            "outputs": {},
        }

        result = CodeNode._extract_variable_selector_to_variable_mapping(
            graph_config={},
            node_id="code_node",
            node_data=node_data,
        )

        assert len(result) == 3
        assert "code_node.var1" in result
        assert "code_node.var2" in result
        assert "code_node.var3" in result

    def test_extract_with_nested_selector(self):
        """Test extraction with nested value selector."""
        node_data = {
            "title": "Test",
            "variables": [
                {"variable": "deep_var", "value_selector": ["node", "obj", "nested", "value"]},
            ],
            "code_language": "python3",
            "code": "def main(): return {}",
            "outputs": {},
        }

        result = CodeNode._extract_variable_selector_to_variable_mapping(
            graph_config={},
            node_id="node_x",
            node_data=node_data,
        )

        assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"]


class TestCodeNodeDataValidation:
    """Test suite for CodeNodeData validation scenarios."""

    def test_valid_python3_code_node_data(self):
        """Test valid Python3 CodeNodeData."""
        data = CodeNodeData(
            title="Python Code",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {'result': 1}",
            outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
        )

        assert data.code_language == CodeLanguage.PYTHON3

    def test_valid_javascript_code_node_data(self):
        """Test valid JavaScript CodeNodeData."""
        data = CodeNodeData(
            title="JS Code",
            variables=[],
            code_language=CodeLanguage.JAVASCRIPT,
            code="function main() { return { result: 1 }; }",
            outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
        )

        assert data.code_language == CodeLanguage.JAVASCRIPT

    def test_code_node_data_with_all_output_types(self):
        """Test CodeNodeData with all valid output types."""
        data = CodeNodeData(
            title="All Types",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {}",
            outputs={
                "str_out": CodeNodeData.Output(type=SegmentType.STRING),
                "num_out": CodeNodeData.Output(type=SegmentType.NUMBER),
                "bool_out": CodeNodeData.Output(type=SegmentType.BOOLEAN),
                "obj_out": CodeNodeData.Output(type=SegmentType.OBJECT),
                "arr_str": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
                "arr_num": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER),
                "arr_bool": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN),
                "arr_obj": CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT),
            },
        )

        assert len(data.outputs) == 8

    def test_code_node_data_complex_nested_output(self):
        """Test CodeNodeData with complex nested output structure."""
        data = CodeNodeData(
            title="Complex Output",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {}",
            outputs={
                "response": CodeNodeData.Output(
                    type=SegmentType.OBJECT,
                    children={
                        "data": CodeNodeData.Output(
                            type=SegmentType.OBJECT,
                            children={
                                "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
                                "count": CodeNodeData.Output(type=SegmentType.NUMBER),
                            },
                        ),
                        "status": CodeNodeData.Output(type=SegmentType.STRING),
                        "success": CodeNodeData.Output(type=SegmentType.BOOLEAN),
                    },
                ),
            },
        )

        assert data.outputs["response"].type == SegmentType.OBJECT
        assert data.outputs["response"].children is not None
        assert "data" in data.outputs["response"].children
        assert data.outputs["response"].children["data"].children is not None


class TestCodeNodeInitialization:
    """Test suite for CodeNode initialization methods."""

    def test_init_node_data_python3(self):
        """Test init_node_data with Python3 configuration."""
        node = CodeNode.__new__(CodeNode)
        data = {
            "title": "Test Node",
            "variables": [],
            "code_language": "python3",
            "code": "def main(): return {'x': 1}",
            "outputs": {"x": {"type": "number"}},
        }

        node.init_node_data(data)

        assert node._node_data.title == "Test Node"
        assert node._node_data.code_language == CodeLanguage.PYTHON3

    def test_init_node_data_javascript(self):
        """Test init_node_data with JavaScript configuration."""
        node = CodeNode.__new__(CodeNode)
        data = {
            "title": "JS Node",
            "variables": [],
            "code_language": "javascript",
            "code": "function main() { return { x: 1 }; }",
            "outputs": {"x": {"type": "number"}},
        }

        node.init_node_data(data)

        assert node._node_data.code_language == CodeLanguage.JAVASCRIPT

    def test_get_title(self):
        """Test _get_title method."""
        node = CodeNode.__new__(CodeNode)
        node._node_data = CodeNodeData(
            title="My Code Node",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="",
            outputs={},
        )

        assert node._get_title() == "My Code Node"

    def test_get_description_none(self):
        """Test _get_description returns None when not set."""
        node = CodeNode.__new__(CodeNode)
        node._node_data = CodeNodeData(
            title="Test",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="",
            outputs={},
        )

        assert node._get_description() is None

    def test_get_base_node_data(self):
        """Test get_base_node_data returns node data."""
        node = CodeNode.__new__(CodeNode)
        node._node_data = CodeNodeData(
            title="Base Test",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="",
            outputs={},
        )

        result = node.get_base_node_data()

        assert result == node._node_data
        assert result.title == "Base Test"
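Several _check_string tests above pin down one behavior: null bytes are stripped from string outputs, presumably because storage backends such as PostgreSQL reject \x00 in text. A standalone sketch of that sanitization step, inferred from the assertions (the real method may enforce additional constraints such as length limits):

    # Sketch only: mirrors what the tests expect _check_string to do for str inputs.
    def sanitize_string_output(value):
        if value is None:
            return None
        # Strip null bytes, which many storage backends refuse to store
        return value.replace("\x00", "")

    assert sanitize_string_output("hello\x00world") == "helloworld"
    assert sanitize_string_output(None) is None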
@@ -0,0 +1,353 @@
import pytest
from pydantic import ValidationError

from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.code.entities import CodeNodeData


class TestCodeNodeDataOutput:
    """Test suite for CodeNodeData.Output model."""

    def test_output_with_string_type(self):
        """Test Output with STRING type."""
        output = CodeNodeData.Output(type=SegmentType.STRING)

        assert output.type == SegmentType.STRING
        assert output.children is None

    def test_output_with_number_type(self):
        """Test Output with NUMBER type."""
        output = CodeNodeData.Output(type=SegmentType.NUMBER)

        assert output.type == SegmentType.NUMBER
        assert output.children is None

    def test_output_with_boolean_type(self):
        """Test Output with BOOLEAN type."""
        output = CodeNodeData.Output(type=SegmentType.BOOLEAN)

        assert output.type == SegmentType.BOOLEAN

    def test_output_with_object_type(self):
        """Test Output with OBJECT type."""
        output = CodeNodeData.Output(type=SegmentType.OBJECT)

        assert output.type == SegmentType.OBJECT

    def test_output_with_array_string_type(self):
        """Test Output with ARRAY_STRING type."""
        output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING)

        assert output.type == SegmentType.ARRAY_STRING

    def test_output_with_array_number_type(self):
        """Test Output with ARRAY_NUMBER type."""
        output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)

        assert output.type == SegmentType.ARRAY_NUMBER

    def test_output_with_array_object_type(self):
        """Test Output with ARRAY_OBJECT type."""
        output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT)

        assert output.type == SegmentType.ARRAY_OBJECT

    def test_output_with_array_boolean_type(self):
        """Test Output with ARRAY_BOOLEAN type."""
        output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)

        assert output.type == SegmentType.ARRAY_BOOLEAN

    def test_output_with_nested_children(self):
        """Test Output with nested children for OBJECT type."""
        child_output = CodeNodeData.Output(type=SegmentType.STRING)
        parent_output = CodeNodeData.Output(
            type=SegmentType.OBJECT,
            children={"name": child_output},
        )

        assert parent_output.type == SegmentType.OBJECT
        assert parent_output.children is not None
        assert "name" in parent_output.children
        assert parent_output.children["name"].type == SegmentType.STRING

    def test_output_with_deeply_nested_children(self):
        """Test Output with deeply nested children."""
        inner_child = CodeNodeData.Output(type=SegmentType.NUMBER)
        middle_child = CodeNodeData.Output(
            type=SegmentType.OBJECT,
            children={"value": inner_child},
        )
        outer_output = CodeNodeData.Output(
            type=SegmentType.OBJECT,
            children={"nested": middle_child},
        )

        assert outer_output.children is not None
        assert outer_output.children["nested"].children is not None
        assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER

    def test_output_with_multiple_children(self):
        """Test Output with multiple children."""
        output = CodeNodeData.Output(
            type=SegmentType.OBJECT,
            children={
                "name": CodeNodeData.Output(type=SegmentType.STRING),
                "age": CodeNodeData.Output(type=SegmentType.NUMBER),
                "active": CodeNodeData.Output(type=SegmentType.BOOLEAN),
            },
        )

        assert output.children is not None
        assert len(output.children) == 3
        assert output.children["name"].type == SegmentType.STRING
        assert output.children["age"].type == SegmentType.NUMBER
        assert output.children["active"].type == SegmentType.BOOLEAN

    def test_output_rejects_invalid_type(self):
        """Test Output rejects invalid segment types."""
        with pytest.raises(ValidationError):
            CodeNodeData.Output(type=SegmentType.FILE)

    def test_output_rejects_array_file_type(self):
        """Test Output rejects ARRAY_FILE type."""
        with pytest.raises(ValidationError):
            CodeNodeData.Output(type=SegmentType.ARRAY_FILE)


class TestCodeNodeDataDependency:
    """Test suite for CodeNodeData.Dependency model."""

    def test_dependency_basic(self):
        """Test Dependency with name and version."""
        dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0")

        assert dependency.name == "numpy"
        assert dependency.version == "1.24.0"

    def test_dependency_with_complex_version(self):
        """Test Dependency with complex version string."""
        dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0")

        assert dependency.name == "pandas"
        assert dependency.version == ">=2.0.0,<3.0.0"

    def test_dependency_with_empty_version(self):
        """Test Dependency with empty version."""
        dependency = CodeNodeData.Dependency(name="requests", version="")

        assert dependency.name == "requests"
        assert dependency.version == ""


class TestCodeNodeData:
    """Test suite for CodeNodeData model."""

    def test_code_node_data_python3(self):
        """Test CodeNodeData with Python3 language."""
        data = CodeNodeData(
            title="Test Code Node",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {'result': 42}",
            outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)},
        )

        assert data.title == "Test Code Node"
        assert data.code_language == CodeLanguage.PYTHON3
        assert data.code == "def main(): return {'result': 42}"
        assert "result" in data.outputs
        assert data.dependencies is None

    def test_code_node_data_javascript(self):
        """Test CodeNodeData with JavaScript language."""
        data = CodeNodeData(
            title="JS Code Node",
            variables=[],
            code_language=CodeLanguage.JAVASCRIPT,
            code="function main() { return { result: 'hello' }; }",
            outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)},
        )

        assert data.code_language == CodeLanguage.JAVASCRIPT
        assert "result" in data.outputs
        assert data.outputs["result"].type == SegmentType.STRING

    def test_code_node_data_with_dependencies(self):
        """Test CodeNodeData with dependencies."""
        data = CodeNodeData(
            title="Code with Deps",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="import numpy as np\ndef main(): return {'sum': 10}",
            outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)},
            dependencies=[
                CodeNodeData.Dependency(name="numpy", version="1.24.0"),
                CodeNodeData.Dependency(name="pandas", version="2.0.0"),
            ],
        )

        assert data.dependencies is not None
        assert len(data.dependencies) == 2
        assert data.dependencies[0].name == "numpy"
        assert data.dependencies[1].name == "pandas"

    def test_code_node_data_with_multiple_outputs(self):
        """Test CodeNodeData with multiple outputs."""
        data = CodeNodeData(
            title="Multi Output",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}",
            outputs={
                "name": CodeNodeData.Output(type=SegmentType.STRING),
                "count": CodeNodeData.Output(type=SegmentType.NUMBER),
                "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING),
            },
        )

        assert len(data.outputs) == 3
        assert data.outputs["name"].type == SegmentType.STRING
        assert data.outputs["count"].type == SegmentType.NUMBER
        assert data.outputs["items"].type == SegmentType.ARRAY_STRING

    def test_code_node_data_with_object_output(self):
        """Test CodeNodeData with nested object output."""
        data = CodeNodeData(
            title="Object Output",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {'user': {'name': 'John', 'age': 30}}",
            outputs={
                "user": CodeNodeData.Output(
                    type=SegmentType.OBJECT,
                    children={
                        "name": CodeNodeData.Output(type=SegmentType.STRING),
                        "age": CodeNodeData.Output(type=SegmentType.NUMBER),
                    },
                ),
            },
        )

        assert data.outputs["user"].type == SegmentType.OBJECT
        assert data.outputs["user"].children is not None
        assert len(data.outputs["user"].children) == 2

    def test_code_node_data_with_array_object_output(self):
        """Test CodeNodeData with array of objects output."""
        data = CodeNodeData(
            title="Array Object Output",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}",
            outputs={
                "users": CodeNodeData.Output(
                    type=SegmentType.ARRAY_OBJECT,
                    children={
                        "name": CodeNodeData.Output(type=SegmentType.STRING),
                    },
                ),
            },
        )

        assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT
        assert data.outputs["users"].children is not None

    def test_code_node_data_empty_code(self):
        """Test CodeNodeData with empty code."""
        data = CodeNodeData(
            title="Empty Code",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="",
            outputs={},
        )

        assert data.code == ""
        assert len(data.outputs) == 0

    def test_code_node_data_multiline_code(self):
        """Test CodeNodeData with multiline code."""
        multiline_code = """
def main():
    result = 0
    for i in range(10):
        result += i
    return {'sum': result}
"""
        data = CodeNodeData(
            title="Multiline Code",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code=multiline_code,
            outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)},
        )

        assert "for i in range(10)" in data.code
        assert "result += i" in data.code

    def test_code_node_data_with_special_characters_in_code(self):
        """Test CodeNodeData with special characters in code."""
        code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}"
        data = CodeNodeData(
            title="Special Chars",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code=code_with_special,
            outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)},
        )

        assert "\\n" in data.code
        assert "\\t" in data.code

    def test_code_node_data_with_unicode_in_code(self):
        """Test CodeNodeData with unicode characters in code."""
        unicode_code = "def main(): return {'greeting': '你好世界'}"
        data = CodeNodeData(
            title="Unicode Code",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code=unicode_code,
            outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)},
        )

        assert "你好世界" in data.code

    def test_code_node_data_empty_dependencies_list(self):
        """Test CodeNodeData with empty dependencies list."""
        data = CodeNodeData(
            title="No Deps",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {}",
            outputs={},
            dependencies=[],
        )

        assert data.dependencies is not None
        assert len(data.dependencies) == 0

    def test_code_node_data_with_boolean_array_output(self):
        """Test CodeNodeData with boolean array output."""
        data = CodeNodeData(
            title="Boolean Array",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {'flags': [True, False, True]}",
            outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)},
        )

        assert data.outputs["flags"].type == SegmentType.ARRAY_BOOLEAN

    def test_code_node_data_with_number_array_output(self):
        """Test CodeNodeData with number array output."""
        data = CodeNodeData(
            title="Number Array",
            variables=[],
            code_language=CodeLanguage.PYTHON3,
            code="def main(): return {'values': [1, 2, 3, 4, 5]}",
            outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)},
        )

        assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER
@@ -0,0 +1,339 @@
from core.workflow.nodes.iteration.entities import (
    ErrorHandleMode,
    IterationNodeData,
    IterationStartNodeData,
    IterationState,
)


class TestErrorHandleMode:
    """Test suite for ErrorHandleMode enum."""

    def test_terminated_value(self):
        """Test TERMINATED enum value."""
        assert ErrorHandleMode.TERMINATED == "terminated"
        assert ErrorHandleMode.TERMINATED.value == "terminated"

    def test_continue_on_error_value(self):
        """Test CONTINUE_ON_ERROR enum value."""
        assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
        assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error"

    def test_remove_abnormal_output_value(self):
        """Test REMOVE_ABNORMAL_OUTPUT enum value."""
        assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output"
        assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output"

    def test_error_handle_mode_is_str_enum(self):
        """Test ErrorHandleMode is a string enum."""
        assert isinstance(ErrorHandleMode.TERMINATED, str)
        assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str)
        assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str)

    def test_error_handle_mode_comparison(self):
        """Test ErrorHandleMode can be compared with strings."""
        assert ErrorHandleMode.TERMINATED == "terminated"
        assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"

    def test_all_error_handle_modes(self):
        """Test all ErrorHandleMode values are accessible."""
        modes = list(ErrorHandleMode)

        assert len(modes) == 3
        assert ErrorHandleMode.TERMINATED in modes
        assert ErrorHandleMode.CONTINUE_ON_ERROR in modes
        assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes


class TestIterationNodeData:
    """Test suite for IterationNodeData model."""

    def test_iteration_node_data_basic(self):
        """Test IterationNodeData with basic configuration."""
        data = IterationNodeData(
            title="Test Iteration",
            iterator_selector=["node1", "output"],
            output_selector=["iteration", "result"],
        )

        assert data.title == "Test Iteration"
        assert data.iterator_selector == ["node1", "output"]
        assert data.output_selector == ["iteration", "result"]

    def test_iteration_node_data_default_values(self):
        """Test IterationNodeData default values."""
        data = IterationNodeData(
            title="Default Test",
            iterator_selector=["start", "items"],
            output_selector=["iter", "out"],
        )

        assert data.parent_loop_id is None
        assert data.is_parallel is False
        assert data.parallel_nums == 10
        assert data.error_handle_mode == ErrorHandleMode.TERMINATED
        assert data.flatten_output is True

    def test_iteration_node_data_parallel_mode(self):
        """Test IterationNodeData with parallel mode enabled."""
        data = IterationNodeData(
            title="Parallel Iteration",
            iterator_selector=["node", "list"],
            output_selector=["iter", "output"],
            is_parallel=True,
            parallel_nums=5,
        )

        assert data.is_parallel is True
        assert data.parallel_nums == 5

    def test_iteration_node_data_custom_parallel_nums(self):
        """Test IterationNodeData with custom parallel numbers."""
        data = IterationNodeData(
            title="Custom Parallel",
            iterator_selector=["a", "b"],
            output_selector=["c", "d"],
            parallel_nums=20,
        )

        assert data.parallel_nums == 20

    def test_iteration_node_data_continue_on_error(self):
        """Test IterationNodeData with continue on error mode."""
        data = IterationNodeData(
            title="Continue Error",
            iterator_selector=["x", "y"],
            output_selector=["z", "w"],
            error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
        )

        assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR

    def test_iteration_node_data_remove_abnormal_output(self):
        """Test IterationNodeData with remove abnormal output mode."""
        data = IterationNodeData(
            title="Remove Abnormal",
            iterator_selector=["input", "array"],
            output_selector=["output", "result"],
            error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
        )

        assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT

    def test_iteration_node_data_flatten_output_disabled(self):
        """Test IterationNodeData with flatten output disabled."""
        data = IterationNodeData(
            title="No Flatten",
            iterator_selector=["a"],
            output_selector=["b"],
            flatten_output=False,
        )

        assert data.flatten_output is False

    def test_iteration_node_data_with_parent_loop_id(self):
        """Test IterationNodeData with parent loop ID."""
        data = IterationNodeData(
            title="Nested Loop",
            iterator_selector=["parent", "items"],
            output_selector=["child", "output"],
            parent_loop_id="parent_loop_123",
        )

        assert data.parent_loop_id == "parent_loop_123"

    def test_iteration_node_data_complex_selectors(self):
        """Test IterationNodeData with complex selectors."""
        data = IterationNodeData(
            title="Complex Selectors",
            iterator_selector=["node1", "output", "data", "items"],
            output_selector=["iteration", "result", "value"],
        )

        assert len(data.iterator_selector) == 4
        assert len(data.output_selector) == 3

    def test_iteration_node_data_all_options(self):
        """Test IterationNodeData with all options configured."""
        data = IterationNodeData(
            title="Full Config",
            iterator_selector=["start", "list"],
            output_selector=["end", "result"],
            parent_loop_id="outer_loop",
            is_parallel=True,
            parallel_nums=15,
            error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
            flatten_output=False,
        )

        assert data.title == "Full Config"
        assert data.parent_loop_id == "outer_loop"
        assert data.is_parallel is True
        assert data.parallel_nums == 15
        assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
        assert data.flatten_output is False


class TestIterationStartNodeData:
    """Test suite for IterationStartNodeData model."""

    def test_iteration_start_node_data_basic(self):
        """Test IterationStartNodeData basic creation."""
        data = IterationStartNodeData(title="Iteration Start")

        assert data.title == "Iteration Start"

    def test_iteration_start_node_data_with_description(self):
        """Test IterationStartNodeData with description."""
        data = IterationStartNodeData(
            title="Start Node",
            desc="This is the start of iteration",
        )

        assert data.title == "Start Node"
        assert data.desc == "This is the start of iteration"


class TestIterationState:
    """Test suite for IterationState model."""

    def test_iteration_state_default_values(self):
        """Test IterationState default values."""
        state = IterationState()

        assert state.outputs == []
        assert state.current_output is None

    def test_iteration_state_with_outputs(self):
        """Test IterationState with outputs."""
        state = IterationState(outputs=["result1", "result2", "result3"])

        assert len(state.outputs) == 3
        assert state.outputs[0] == "result1"
        assert state.outputs[2] == "result3"

    def test_iteration_state_with_current_output(self):
        """Test IterationState with current output."""
        state = IterationState(current_output="current_value")

        assert state.current_output == "current_value"

    def test_iteration_state_get_last_output_with_outputs(self):
        """Test get_last_output with outputs present."""
        state = IterationState(outputs=["first", "second", "last"])

        result = state.get_last_output()

        assert result == "last"

    def test_iteration_state_get_last_output_empty(self):
        """Test get_last_output with empty outputs."""
        state = IterationState(outputs=[])

        result = state.get_last_output()

        assert result is None

    def test_iteration_state_get_last_output_single(self):
        """Test get_last_output with single output."""
        state = IterationState(outputs=["only_one"])

        result = state.get_last_output()

        assert result == "only_one"

    def test_iteration_state_get_current_output(self):
        """Test get_current_output method."""
        state = IterationState(current_output={"key": "value"})

        result = state.get_current_output()

        assert result == {"key": "value"}

    def test_iteration_state_get_current_output_none(self):
        """Test get_current_output when None."""
        state = IterationState()

        result = state.get_current_output()

        assert result is None

    def test_iteration_state_with_complex_outputs(self):
        """Test IterationState with complex output types."""
        state = IterationState(
            outputs=[
                {"id": 1, "name": "first"},
                {"id": 2, "name": "second"},
                [1, 2, 3],
                "string_output",
            ]
        )

        assert len(state.outputs) == 4
        assert state.outputs[0] == {"id": 1, "name": "first"}
        assert state.outputs[2] == [1, 2, 3]

    def test_iteration_state_with_none_outputs(self):
        """Test IterationState with None values in outputs."""
        state = IterationState(outputs=["value1", None, "value3"])

        assert len(state.outputs) == 3
        assert state.outputs[1] is None

    def test_iteration_state_get_last_output_with_none(self):
        """Test get_last_output when last output is None."""
        state = IterationState(outputs=["first", None])
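        # The final element is None itself, so the None returned below is
        # indistinguishable from the empty-outputs case.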

        result = state.get_last_output()

        assert result is None

    def test_iteration_state_metadata_class(self):
        """Test IterationState.MetaData class."""
        metadata = IterationState.MetaData(iterator_length=10)

        assert metadata.iterator_length == 10

    def test_iteration_state_metadata_different_lengths(self):
        """Test IterationState.MetaData with different lengths."""
        metadata1 = IterationState.MetaData(iterator_length=0)
        metadata2 = IterationState.MetaData(iterator_length=100)
        metadata3 = IterationState.MetaData(iterator_length=1000000)

        assert metadata1.iterator_length == 0
        assert metadata2.iterator_length == 100
        assert metadata3.iterator_length == 1000000

    def test_iteration_state_outputs_modification(self):
        """Test modifying IterationState outputs."""
        state = IterationState(outputs=[])

        state.outputs.append("new_output")
        state.outputs.append("another_output")

        assert len(state.outputs) == 2
        assert state.get_last_output() == "another_output"

    def test_iteration_state_current_output_update(self):
        """Test updating current_output."""
        state = IterationState()

        state.current_output = "first_value"
        assert state.get_current_output() == "first_value"

        state.current_output = "updated_value"
        assert state.get_current_output() == "updated_value"

    def test_iteration_state_with_numeric_outputs(self):
        """Test IterationState with numeric outputs."""
        state = IterationState(outputs=[1, 2, 3, 4, 5])

        assert state.get_last_output() == 5
        assert len(state.outputs) == 5

    def test_iteration_state_with_boolean_outputs(self):
        """Test IterationState with boolean outputs."""
        state = IterationState(outputs=[True, False, True])

        assert state.get_last_output() is True
        assert state.outputs[1] is False
@@ -0,0 +1,390 @@
from core.workflow.enums import NodeType
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.nodes.iteration.exc import (
    InvalidIteratorValueError,
    IterationGraphNotFoundError,
    IterationIndexNotFoundError,
    IterationNodeError,
    IteratorVariableNotFoundError,
    StartNodeIdNotFoundError,
)
from core.workflow.nodes.iteration.iteration_node import IterationNode


class TestIterationNodeExceptions:
    """Test suite for iteration node exceptions."""

    def test_iteration_node_error_is_value_error(self):
        """Test IterationNodeError inherits from ValueError."""
        error = IterationNodeError("test error")

        assert isinstance(error, ValueError)
        assert str(error) == "test error"

    def test_iterator_variable_not_found_error(self):
        """Test IteratorVariableNotFoundError."""
        error = IteratorVariableNotFoundError("Iterator variable not found")

        assert isinstance(error, IterationNodeError)
        assert isinstance(error, ValueError)
        assert "Iterator variable not found" in str(error)

    def test_invalid_iterator_value_error(self):
        """Test InvalidIteratorValueError."""
        error = InvalidIteratorValueError("Invalid iterator value")

        assert isinstance(error, IterationNodeError)
        assert "Invalid iterator value" in str(error)

    def test_start_node_id_not_found_error(self):
        """Test StartNodeIdNotFoundError."""
        error = StartNodeIdNotFoundError("Start node ID not found")

        assert isinstance(error, IterationNodeError)
        assert "Start node ID not found" in str(error)

    def test_iteration_graph_not_found_error(self):
        """Test IterationGraphNotFoundError."""
        error = IterationGraphNotFoundError("Iteration graph not found")

        assert isinstance(error, IterationNodeError)
        assert "Iteration graph not found" in str(error)

    def test_iteration_index_not_found_error(self):
        """Test IterationIndexNotFoundError."""
        error = IterationIndexNotFoundError("Iteration index not found")

        assert isinstance(error, IterationNodeError)
        assert "Iteration index not found" in str(error)

    def test_exception_with_empty_message(self):
        """Test exception with empty message."""
        error = IterationNodeError("")

        assert str(error) == ""

    def test_exception_with_detailed_message(self):
        """Test exception with detailed message."""
        error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'")

        assert "items" in str(error)
        assert "start_node" in str(error)

    def test_all_exceptions_inherit_from_base(self):
        """Test all exceptions inherit from IterationNodeError."""
        exceptions = [
            IteratorVariableNotFoundError("test"),
            InvalidIteratorValueError("test"),
            StartNodeIdNotFoundError("test"),
            IterationGraphNotFoundError("test"),
            IterationIndexNotFoundError("test"),
        ]

        for exc in exceptions:
            assert isinstance(exc, IterationNodeError)
            assert isinstance(exc, ValueError)


class TestIterationNodeClassAttributes:
    """Test suite for IterationNode class attributes."""

    def test_node_type(self):
        """Test IterationNode node_type attribute."""
        assert IterationNode.node_type == NodeType.ITERATION

    def test_version(self):
        """Test IterationNode version method."""
        version = IterationNode.version()

        assert version == "1"


class TestIterationNodeDefaultConfig:
    """Test suite for IterationNode get_default_config."""

    def test_get_default_config_returns_dict(self):
        """Test get_default_config returns a dictionary."""
        config = IterationNode.get_default_config()

        assert isinstance(config, dict)

    def test_get_default_config_type(self):
        """Test get_default_config includes type."""
        config = IterationNode.get_default_config()

        assert config.get("type") == "iteration"

    def test_get_default_config_has_config_section(self):
        """Test get_default_config has config section."""
        config = IterationNode.get_default_config()

        assert "config" in config
        assert isinstance(config["config"], dict)

    def test_get_default_config_is_parallel_default(self):
        """Test get_default_config is_parallel default value."""
        config = IterationNode.get_default_config()

        assert config["config"]["is_parallel"] is False

    def test_get_default_config_parallel_nums_default(self):
        """Test get_default_config parallel_nums default value."""
        config = IterationNode.get_default_config()

        assert config["config"]["parallel_nums"] == 10

    def test_get_default_config_error_handle_mode_default(self):
        """Test get_default_config error_handle_mode default value."""
        config = IterationNode.get_default_config()

        assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED

    def test_get_default_config_flatten_output_default(self):
        """Test get_default_config flatten_output default value."""
        config = IterationNode.get_default_config()

        assert config["config"]["flatten_output"] is True

    def test_get_default_config_with_none_filters(self):
        """Test get_default_config with None filters."""
        config = IterationNode.get_default_config(filters=None)

        assert config is not None
        assert "type" in config

    def test_get_default_config_with_empty_filters(self):
        """Test get_default_config with empty filters."""
        config = IterationNode.get_default_config(filters={})

        assert config is not None


class TestIterationNodeInitialization:
    """Test suite for IterationNode initialization."""

    def test_init_node_data_basic(self):
        """Test init_node_data with basic configuration."""
        node = IterationNode.__new__(IterationNode)
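        # __new__ skips __init__, so init_node_data can be exercised without a full graph setup.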
        data = {
            "title": "Test Iteration",
            "iterator_selector": ["start", "items"],
            "output_selector": ["iteration", "result"],
        }

        node.init_node_data(data)

        assert node._node_data.title == "Test Iteration"
        assert node._node_data.iterator_selector == ["start", "items"]

    def test_init_node_data_with_parallel(self):
        """Test init_node_data with parallel configuration."""
        node = IterationNode.__new__(IterationNode)
        data = {
            "title": "Parallel Iteration",
            "iterator_selector": ["node", "list"],
            "output_selector": ["out", "result"],
            "is_parallel": True,
            "parallel_nums": 5,
        }

        node.init_node_data(data)

        assert node._node_data.is_parallel is True
        assert node._node_data.parallel_nums == 5

    def test_init_node_data_with_error_handle_mode(self):
        """Test init_node_data with error handle mode."""
        node = IterationNode.__new__(IterationNode)
        data = {
            "title": "Error Handle Test",
            "iterator_selector": ["a", "b"],
            "output_selector": ["c", "d"],
            "error_handle_mode": "continue-on-error",
        }

        node.init_node_data(data)

        assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR

    def test_get_title(self):
        """Test _get_title method."""
        node = IterationNode.__new__(IterationNode)
        node._node_data = IterationNodeData(
            title="My Iteration",
            iterator_selector=["x"],
            output_selector=["y"],
        )

        assert node._get_title() == "My Iteration"

    def test_get_description_none(self):
        """Test _get_description returns None when not set."""
        node = IterationNode.__new__(IterationNode)
        node._node_data = IterationNodeData(
            title="Test",
            iterator_selector=["a"],
            output_selector=["b"],
        )

        assert node._get_description() is None

    def test_get_description_with_value(self):
        """Test _get_description with value."""
        node = IterationNode.__new__(IterationNode)
        node._node_data = IterationNodeData(
            title="Test",
            desc="This is a description",
            iterator_selector=["a"],
            output_selector=["b"],
        )

        assert node._get_description() == "This is a description"

    def test_get_base_node_data(self):
        """Test get_base_node_data returns node data."""
        node = IterationNode.__new__(IterationNode)
        node._node_data = IterationNodeData(
            title="Base Test",
            iterator_selector=["x"],
            output_selector=["y"],
        )

        result = node.get_base_node_data()

        assert result == node._node_data


class TestIterationNodeDataValidation:
    """Test suite for IterationNodeData validation scenarios."""

    def test_valid_iteration_node_data(self):
        """Test valid IterationNodeData creation."""
        data = IterationNodeData(
            title="Valid Iteration",
            iterator_selector=["start", "items"],
            output_selector=["end", "result"],
        )

        assert data.title == "Valid Iteration"

    def test_iteration_node_data_with_all_error_modes(self):
        """Test IterationNodeData with all error handle modes."""
        modes = [
            ErrorHandleMode.TERMINATED,
            ErrorHandleMode.CONTINUE_ON_ERROR,
            ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
        ]

        for mode in modes:
            data = IterationNodeData(
                title=f"Test {mode}",
                iterator_selector=["a"],
                output_selector=["b"],
                error_handle_mode=mode,
            )
            assert data.error_handle_mode == mode

    def test_iteration_node_data_parallel_configuration(self):
        """Test IterationNodeData parallel configuration combinations."""
        configs = [
            (False, 10),
            (True, 1),
            (True, 5),
            (True, 20),
            (True, 100),
        ]

        for is_parallel, parallel_nums in configs:
            data = IterationNodeData(
                title="Parallel Test",
                iterator_selector=["x"],
                output_selector=["y"],
                is_parallel=is_parallel,
                parallel_nums=parallel_nums,
            )
            assert data.is_parallel == is_parallel
            assert data.parallel_nums == parallel_nums

    def test_iteration_node_data_flatten_output_options(self):
        """Test IterationNodeData flatten_output options."""
        data_flatten = IterationNodeData(
            title="Flatten True",
            iterator_selector=["a"],
            output_selector=["b"],
            flatten_output=True,
        )

        data_no_flatten = IterationNodeData(
            title="Flatten False",
            iterator_selector=["a"],
            output_selector=["b"],
            flatten_output=False,
        )

        assert data_flatten.flatten_output is True
        assert data_no_flatten.flatten_output is False

    def test_iteration_node_data_complex_selectors(self):
        """Test IterationNodeData with complex selectors."""
        data = IterationNodeData(
            title="Complex",
            iterator_selector=["node1", "output", "data", "items", "list"],
            output_selector=["iteration", "result", "value", "final"],
        )

        assert len(data.iterator_selector) == 5
        assert len(data.output_selector) == 4

    def test_iteration_node_data_single_element_selectors(self):
        """Test IterationNodeData with single element selectors."""
        data = IterationNodeData(
            title="Single",
            iterator_selector=["items"],
            output_selector=["result"],
        )

        assert len(data.iterator_selector) == 1
        assert len(data.output_selector) == 1


class TestIterationNodeErrorStrategies:
    """Test suite for IterationNode error strategies."""

    def test_get_error_strategy_default(self):
        """Test _get_error_strategy with default value."""
        node = IterationNode.__new__(IterationNode)
        node._node_data = IterationNodeData(
            title="Test",
            iterator_selector=["a"],
            output_selector=["b"],
        )

        result = node._get_error_strategy()

        assert result is None or result == node._node_data.error_strategy

    def test_get_retry_config(self):
        """Test _get_retry_config method."""
        node = IterationNode.__new__(IterationNode)
        node._node_data = IterationNodeData(
            title="Test",
            iterator_selector=["a"],
            output_selector=["b"],
        )

        result = node._get_retry_config()

        assert result is not None

    def test_get_default_value_dict(self):
        """Test _get_default_value_dict method."""
        node = IterationNode.__new__(IterationNode)
        node._node_data = IterationNodeData(
            title="Test",
            iterator_selector=["a"],
            output_selector=["b"],
        )

        result = node._get_default_value_dict()

        assert isinstance(result, dict)
@@ -0,0 +1 @@
@@ -0,0 +1,544 @@
from unittest.mock import MagicMock

import pytest
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

from core.variables import ArrayNumberSegment, ArrayStringSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.list_operator.node import ListOperatorNode
from models.workflow import WorkflowType


class TestListOperatorNode:
    """Comprehensive tests for ListOperatorNode."""

    @pytest.fixture
    def mock_graph_runtime_state(self):
        """Create mock GraphRuntimeState."""
        mock_state = MagicMock(spec=GraphRuntimeState)
        mock_variable_pool = MagicMock()
        mock_state.variable_pool = mock_variable_pool
        return mock_state

    @pytest.fixture
    def mock_graph(self):
        """Create mock Graph."""
        return MagicMock(spec=Graph)

    @pytest.fixture
    def graph_init_params(self):
        """Create GraphInitParams fixture."""
        return GraphInitParams(
            tenant_id="test",
            app_id="test",
            workflow_type=WorkflowType.WORKFLOW,
            workflow_id="test",
            graph_config={},
            user_id="test",
            user_from="test",
            invoke_from="test",
            call_depth=0,
        )

    @pytest.fixture
    def list_operator_node_factory(self, graph_init_params, mock_graph, mock_graph_runtime_state):
        """Factory fixture for creating ListOperatorNode instances."""

        def _create_node(config, mock_variable):
            mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
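            # Every variable_pool.get() lookup now resolves to the supplied segment.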
            return ListOperatorNode(
                id="test",
                config=config,
                graph_init_params=graph_init_params,
                graph=mock_graph,
                graph_runtime_state=mock_graph_runtime_state,
            )

        return _create_node

    def test_node_initialization(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test node initializes correctly."""
        config = {
            "title": "List Operator",
            "variable": ["sys", "list"],
            "filter_by": {"enabled": False},
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        assert node.node_type == NodeType.LIST_OPERATOR
        assert node._node_data.title == "List Operator"

    def test_version(self):
        """Test version returns correct value."""
        assert ListOperatorNode.version() == "1"

    def test_run_with_string_array(self, list_operator_node_factory):
        """Test with string array."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {"enabled": False},
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=["apple", "banana", "cherry"])
        node = list_operator_node_factory(config, mock_var)

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == ["apple", "banana", "cherry"]

    def test_run_with_empty_array(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test with empty array."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {"enabled": False},
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=[])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == []
        assert result.outputs["first_record"] is None
        assert result.outputs["last_record"] is None

    def test_run_with_filter_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test filter with contains condition."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {
                "enabled": True,
                "condition": "contains",
                "value": "app",
            },
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == ["apple", "pineapple"]

    def test_run_with_filter_not_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test filter with not contains condition."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {
                "enabled": True,
                "condition": "not contains",
                "value": "app",
            },
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == ["banana", "cherry"]

    def test_run_with_number_filter_greater_than(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test filter with greater than condition on numbers."""
        config = {
            "title": "Test",
            "variable": ["sys", "numbers"],
            "filter_by": {
                "enabled": True,
                "condition": ">",
                "value": "5",
            },
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == [7, 9, 11]

    def test_run_with_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test ordering in ascending order."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {"enabled": False},
            "order_by": {
                "enabled": True,
                "value": "asc",
            },
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == ["apple", "banana", "cherry"]

    def test_run_with_order_descending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test ordering in descending order."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {"enabled": False},
            "order_by": {
                "enabled": True,
                "value": "desc",
            },
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == ["cherry", "banana", "apple"]

    def test_run_with_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test with limit enabled."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {"enabled": False},
            "order_by": {"enabled": False},
            "limit": {
                "enabled": True,
                "size": 2,
            },
        }

        mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == ["apple", "banana"]

    def test_run_with_filter_order_and_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test with filter, order, and limit combined."""
        config = {
            "title": "Test",
            "variable": ["sys", "numbers"],
            "filter_by": {
                "enabled": True,
                "condition": ">",
                "value": "3",
            },
            "order_by": {
                "enabled": True,
                "value": "desc",
            },
            "limit": {
                "enabled": True,
                "size": 3,
            },
        }

        mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == [9, 8, 7]

    def test_run_with_variable_not_found(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test when variable is not found."""
        config = {
            "title": "Test",
            "variable": ["sys", "missing"],
            "filter_by": {"enabled": False},
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_graph_runtime_state.variable_pool.get.return_value = None

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Variable not found" in result.error

    def test_run_with_first_and_last_record(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test first_record and last_record outputs."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {"enabled": False},
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=["first", "middle", "last"])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["first_record"] == "first"
        assert result.outputs["last_record"] == "last"

    def test_run_with_filter_startswith(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test filter with startswith condition."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {
                "enabled": True,
                "condition": "start with",
                "value": "app",
            },
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == ["apple", "application"]

    def test_run_with_filter_endswith(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test filter with endswith condition."""
        config = {
            "title": "Test",
            "variable": ["sys", "items"],
            "filter_by": {
                "enabled": True,
                "condition": "end with",
                "value": "le",
            },
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == ["apple", "pineapple", "table"]

    def test_run_with_number_filter_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test number filter with equals condition."""
        config = {
            "title": "Test",
            "variable": ["sys", "numbers"],
            "filter_by": {
                "enabled": True,
                "condition": "=",
                "value": "5",
            },
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == [5, 5]

    def test_run_with_number_filter_not_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test number filter with not equals condition."""
        config = {
            "title": "Test",
            "variable": ["sys", "numbers"],
            "filter_by": {
                "enabled": True,
                "condition": "≠",
                "value": "5",
            },
            "order_by": {"enabled": False},
            "limit": {"enabled": False},
        }

        mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == [1, 3, 7, 9]

    def test_run_with_number_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test number ordering in ascending order."""
        config = {
            "title": "Test",
            "variable": ["sys", "numbers"],
            "filter_by": {"enabled": False},
            "order_by": {
                "enabled": True,
                "value": "asc",
            },
            "limit": {"enabled": False},
        }

        mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
        mock_graph_runtime_state.variable_pool.get.return_value = mock_var

        node = ListOperatorNode(
            id="test",
            config=config,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["result"].value == [1, 3, 5, 7, 9]
@@ -0,0 +1 @@
@@ -0,0 +1,225 @@
import pytest
from pydantic import ValidationError

from core.workflow.enums import ErrorStrategy
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData


class TestTemplateTransformNodeData:
    """Test suite for TemplateTransformNodeData entity."""

    def test_valid_template_transform_node_data(self):
        """Test creating valid TemplateTransformNodeData."""
        data = {
            "title": "Template Transform",
            "desc": "Transform data using Jinja2 template",
            "variables": [
                {"variable": "name", "value_selector": ["sys", "user_name"]},
                {"variable": "age", "value_selector": ["sys", "user_age"]},
            ],
            "template": "Hello {{ name }}, you are {{ age }} years old!",
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert node_data.title == "Template Transform"
        assert node_data.desc == "Transform data using Jinja2 template"
        assert len(node_data.variables) == 2
        assert node_data.variables[0].variable == "name"
        assert node_data.variables[0].value_selector == ["sys", "user_name"]
        assert node_data.variables[1].variable == "age"
        assert node_data.variables[1].value_selector == ["sys", "user_age"]
        assert node_data.template == "Hello {{ name }}, you are {{ age }} years old!"

    def test_template_transform_node_data_with_empty_variables(self):
        """Test TemplateTransformNodeData with no variables."""
        data = {
            "title": "Static Template",
            "variables": [],
            "template": "This is a static template with no variables.",
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert node_data.title == "Static Template"
        assert len(node_data.variables) == 0
        assert node_data.template == "This is a static template with no variables."

    def test_template_transform_node_data_with_complex_template(self):
        """Test TemplateTransformNodeData with complex Jinja2 template."""
        data = {
            "title": "Complex Template",
            "variables": [
                {"variable": "items", "value_selector": ["sys", "item_list"]},
                {"variable": "total", "value_selector": ["sys", "total_count"]},
            ],
            "template": (
                "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}. Total: {{ total }}"
            ),
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert node_data.title == "Complex Template"
        assert len(node_data.variables) == 2
        assert "{% for item in items %}" in node_data.template
        assert "{{ total }}" in node_data.template

    def test_template_transform_node_data_with_error_strategy(self):
        """Test TemplateTransformNodeData with error handling strategy."""
        data = {
            "title": "Template with Error Handling",
            "variables": [{"variable": "value", "value_selector": ["sys", "input"]}],
            "template": "{{ value }}",
            "error_strategy": "fail-branch",
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert node_data.error_strategy == ErrorStrategy.FAIL_BRANCH

    def test_template_transform_node_data_with_retry_config(self):
        """Test TemplateTransformNodeData with retry configuration."""
        data = {
            "title": "Template with Retry",
            "variables": [{"variable": "data", "value_selector": ["sys", "data"]}],
            "template": "{{ data }}",
            "retry_config": {"enabled": True, "max_retries": 3, "retry_interval": 1000},
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert node_data.retry_config.enabled is True
        assert node_data.retry_config.max_retries == 3
        assert node_data.retry_config.retry_interval == 1000

    def test_template_transform_node_data_missing_required_fields(self):
        """Test that missing required fields raises ValidationError."""
        data = {
            "title": "Incomplete Template",
            # Missing 'variables' and 'template'
        }

        with pytest.raises(ValidationError) as exc_info:
            TemplateTransformNodeData.model_validate(data)

        errors = exc_info.value.errors()
        assert len(errors) >= 2
        error_fields = {error["loc"][0] for error in errors}
        assert "variables" in error_fields
        assert "template" in error_fields

    def test_template_transform_node_data_invalid_variable_selector(self):
        """Test that invalid variable selector format raises ValidationError."""
        data = {
            "title": "Invalid Variable",
            "variables": [
                {"variable": "name", "value_selector": "invalid_format"}  # Should be list
            ],
            "template": "{{ name }}",
        }

        with pytest.raises(ValidationError):
            TemplateTransformNodeData.model_validate(data)

    def test_template_transform_node_data_with_default_value_dict(self):
        """Test TemplateTransformNodeData with default value dictionary."""
        data = {
            "title": "Template with Defaults",
            "variables": [
                {"variable": "name", "value_selector": ["sys", "user_name"]},
                {"variable": "greeting", "value_selector": ["sys", "greeting"]},
            ],
            "template": "{{ greeting }} {{ name }}!",
            "default_value_dict": {"greeting": "Hello", "name": "Guest"},
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert node_data.default_value_dict == {"greeting": "Hello", "name": "Guest"}

    def test_template_transform_node_data_with_nested_selectors(self):
        """Test TemplateTransformNodeData with nested variable selectors."""
        data = {
            "title": "Nested Selectors",
            "variables": [
                {"variable": "user_info", "value_selector": ["sys", "user", "profile", "name"]},
                {"variable": "settings", "value_selector": ["sys", "config", "app", "theme"]},
            ],
            "template": "User: {{ user_info }}, Theme: {{ settings }}",
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert len(node_data.variables) == 2
        assert node_data.variables[0].value_selector == ["sys", "user", "profile", "name"]
        assert node_data.variables[1].value_selector == ["sys", "config", "app", "theme"]

    def test_template_transform_node_data_with_multiline_template(self):
        """Test TemplateTransformNodeData with multiline template."""
        data = {
            "title": "Multiline Template",
            "variables": [
                {"variable": "title", "value_selector": ["sys", "title"]},
                {"variable": "content", "value_selector": ["sys", "content"]},
            ],
            "template": """
# {{ title }}

{{ content }}

---
Generated by Template Transform Node
""",
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert "# {{ title }}" in node_data.template
        assert "{{ content }}" in node_data.template
        assert "Generated by Template Transform Node" in node_data.template

    def test_template_transform_node_data_serialization(self):
        """Test that TemplateTransformNodeData can be serialized and deserialized."""
        original_data = {
            "title": "Serialization Test",
            "desc": "Test serialization",
            "variables": [{"variable": "test", "value_selector": ["sys", "test"]}],
            "template": "{{ test }}",
        }

        node_data = TemplateTransformNodeData.model_validate(original_data)
        serialized = node_data.model_dump()
        deserialized = TemplateTransformNodeData.model_validate(serialized)

        assert deserialized.title == node_data.title
        assert deserialized.desc == node_data.desc
        assert len(deserialized.variables) == len(node_data.variables)
        assert deserialized.template == node_data.template

    def test_template_transform_node_data_with_special_characters(self):
        """Test TemplateTransformNodeData with special characters in template."""
        data = {
            "title": "Special Characters",
            "variables": [{"variable": "text", "value_selector": ["sys", "input"]}],
            "template": "Special: {{ text }} | Symbols: @#$%^&*() | Unicode: 你好 🎉",
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert "@#$%^&*()" in node_data.template
        assert "你好" in node_data.template
        assert "🎉" in node_data.template

    def test_template_transform_node_data_empty_template(self):
        """Test TemplateTransformNodeData with empty template string."""
        data = {
            "title": "Empty Template",
            "variables": [],
            "template": "",
        }

        node_data = TemplateTransformNodeData.model_validate(data)

        assert node_data.template == ""
        assert len(node_data.variables) == 0
@@ -0,0 +1,414 @@
from unittest.mock import MagicMock, patch

import pytest
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

from core.helper.code_executor.code_executor import CodeExecutionError
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowType


class TestTemplateTransformNode:
    """Comprehensive test suite for TemplateTransformNode."""

    @pytest.fixture
    def mock_graph_runtime_state(self):
        """Create a mock GraphRuntimeState with variable pool."""
        mock_state = MagicMock(spec=GraphRuntimeState)
        mock_variable_pool = MagicMock()
        mock_state.variable_pool = mock_variable_pool
        return mock_state

    @pytest.fixture
    def mock_graph(self):
        """Create a mock Graph."""
        return MagicMock(spec=Graph)

    @pytest.fixture
    def graph_init_params(self):
        """Create a mock GraphInitParams."""
        return GraphInitParams(
            tenant_id="test_tenant",
            app_id="test_app",
            workflow_type=WorkflowType.WORKFLOW,
            workflow_id="test_workflow",
            graph_config={},
            user_id="test_user",
            user_from="test",
            invoke_from="test",
            call_depth=0,
        )

    @pytest.fixture
    def basic_node_data(self):
        """Create basic node data for testing."""
        return {
            "title": "Template Transform",
            "desc": "Transform data using template",
            "variables": [
                {"variable": "name", "value_selector": ["sys", "user_name"]},
                {"variable": "age", "value_selector": ["sys", "user_age"]},
            ],
            "template": "Hello {{ name }}, you are {{ age }} years old!",
        }

    def test_node_initialization(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test that TemplateTransformNode initializes correctly."""
        node = TemplateTransformNode(
            id="test_node",
            config=basic_node_data,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        assert node.node_type == NodeType.TEMPLATE_TRANSFORM
        assert node._node_data.title == "Template Transform"
        assert len(node._node_data.variables) == 2
        assert node._node_data.template == "Hello {{ name }}, you are {{ age }} years old!"

    def test_get_title(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test _get_title method."""
        node = TemplateTransformNode(
            id="test_node",
            config=basic_node_data,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        assert node._get_title() == "Template Transform"

    def test_get_description(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test _get_description method."""
        node = TemplateTransformNode(
            id="test_node",
            config=basic_node_data,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        assert node._get_description() == "Transform data using template"

    def test_get_error_strategy(self, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test _get_error_strategy method."""
        node_data = {
            "title": "Test",
            "variables": [],
            "template": "test",
            "error_strategy": "fail-branch",
        }

        node = TemplateTransformNode(
            id="test_node",
            config=node_data,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH

    def test_get_default_config(self):
        """Test get_default_config class method."""
        config = TemplateTransformNode.get_default_config()

        assert config["type"] == "template-transform"
        assert "config" in config
        assert "variables" in config["config"]
        assert "template" in config["config"]
        assert config["config"]["template"] == "{{ arg1 }}"

    def test_version(self):
        """Test version class method."""
        assert TemplateTransformNode.version() == "1"

    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
    def test_run_simple_template(
        self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
    ):
        """Test _run with simple template transformation."""
        # Setup mock variable pool
        mock_name_value = MagicMock()
        mock_name_value.to_object.return_value = "Alice"
        mock_age_value = MagicMock()
        mock_age_value.to_object.return_value = 30

        variable_map = {
            ("sys", "user_name"): mock_name_value,
            ("sys", "user_age"): mock_age_value,
        }
        mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))

        # Setup mock executor
        mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}

        node = TemplateTransformNode(
            id="test_node",
            config=basic_node_data,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["output"] == "Hello Alice, you are 30 years old!"
        assert result.inputs["name"] == "Alice"
        assert result.inputs["age"] == 30

    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
    def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
        """Test _run with None variable values."""
        node_data = {
            "title": "Test",
            "variables": [{"variable": "value", "value_selector": ["sys", "missing"]}],
            "template": "Value: {{ value }}",
        }

        mock_graph_runtime_state.variable_pool.get.return_value = None
        mock_execute.return_value = {"result": "Value: "}

        node = TemplateTransformNode(
            id="test_node",
            config=node_data,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.inputs["value"] is None

    @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
    def test_run_with_code_execution_error(
        self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
    ):
        """Test _run when code execution fails."""
        mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
        mock_execute.side_effect = CodeExecutionError("Template syntax error")

        node = TemplateTransformNode(
            id="test_node",
            config=basic_node_data,
            graph_init_params=graph_init_params,
            graph=mock_graph,
            graph_runtime_state=mock_graph_runtime_state,
        )

        result = node._run()

        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Template syntax error" in result.error

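    # MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH is patched down to 10 in the next
    # test, so the 50-character mocked result trips the length guard without
    # needing a huge payload.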
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
|
||||
def test_run_output_length_exceeds_limit(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
"""Test _run when output exceeds maximum length."""
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
|
||||
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
config=basic_node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Output length exceeds" in result.error
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
def test_run_with_complex_jinja2_template(
|
||||
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
"""Test _run with complex Jinja2 template including loops and conditions."""
|
||||
node_data = {
|
||||
"title": "Complex Template",
|
||||
"variables": [
|
||||
{"variable": "items", "value_selector": ["sys", "items"]},
|
||||
{"variable": "show_total", "value_selector": ["sys", "show_total"]},
|
||||
],
|
||||
"template": (
|
||||
"{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}"
|
||||
"{% if show_total %} (Total: {{ items|length }}){% endif %}"
|
||||
),
|
||||
}
|
||||
|
||||
mock_items = MagicMock()
|
||||
mock_items.to_object.return_value = ["apple", "banana", "orange"]
|
||||
mock_show_total = MagicMock()
|
||||
mock_show_total.to_object.return_value = True
|
||||
|
||||
variable_map = {
|
||||
("sys", "items"): mock_items,
|
||||
("sys", "show_total"): mock_show_total,
|
||||
}
|
||||
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
|
||||
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
config=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["output"] == "apple, banana, orange (Total: 3)"
|
||||
|
||||
def test_extract_variable_selector_to_variable_mapping(self):
|
||||
"""Test _extract_variable_selector_to_variable_mapping class method."""
|
||||
node_data = {
|
||||
"title": "Test",
|
||||
"variables": [
|
||||
{"variable": "var1", "value_selector": ["sys", "input1"]},
|
||||
{"variable": "var2", "value_selector": ["sys", "input2"]},
|
||||
],
|
||||
"template": "{{ var1 }} {{ var2 }}",
|
||||
}
|
||||
|
||||
mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping(
|
||||
graph_config={}, node_id="node_123", node_data=node_data
|
||||
)
|
||||
|
||||
assert "node_123.var1" in mapping
|
||||
assert "node_123.var2" in mapping
|
||||
assert mapping["node_123.var1"] == ["sys", "input1"]
|
||||
assert mapping["node_123.var2"] == ["sys", "input2"]
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with no variables (static template)."""
|
||||
node_data = {
|
||||
"title": "Static Template",
|
||||
"variables": [],
|
||||
"template": "This is a static message.",
|
||||
}
|
||||
|
||||
mock_execute.return_value = {"result": "This is a static message."}
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
config=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["output"] == "This is a static message."
|
||||
assert result.inputs == {}
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with numeric variable values."""
|
||||
node_data = {
|
||||
"title": "Numeric Template",
|
||||
"variables": [
|
||||
{"variable": "price", "value_selector": ["sys", "price"]},
|
||||
{"variable": "quantity", "value_selector": ["sys", "quantity"]},
|
||||
],
|
||||
"template": "Total: ${{ price * quantity }}",
|
||||
}
|
||||
|
||||
mock_price = MagicMock()
|
||||
mock_price.to_object.return_value = 10.5
|
||||
mock_quantity = MagicMock()
|
||||
mock_quantity.to_object.return_value = 3
|
||||
|
||||
variable_map = {
|
||||
("sys", "price"): mock_price,
|
||||
("sys", "quantity"): mock_quantity,
|
||||
}
|
||||
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
|
||||
mock_execute.return_value = {"result": "Total: $31.5"}
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
config=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["output"] == "Total: $31.5"
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with dictionary variable values."""
|
||||
node_data = {
|
||||
"title": "Dict Template",
|
||||
"variables": [{"variable": "user", "value_selector": ["sys", "user_data"]}],
|
||||
"template": "Name: {{ user.name }}, Email: {{ user.email }}",
|
||||
}
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_user
|
||||
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
config=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "John Doe" in result.outputs["output"]
|
||||
assert "john@example.com" in result.outputs["output"]
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with list variable values."""
|
||||
node_data = {
|
||||
"title": "List Template",
|
||||
"variables": [{"variable": "tags", "value_selector": ["sys", "tags"]}],
|
||||
"template": "Tags: {% for tag in tags %}#{{ tag }} {% endfor %}",
|
||||
}
|
||||
|
||||
mock_tags = MagicMock()
|
||||
mock_tags.to_object.return_value = ["python", "ai", "workflow"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
|
||||
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
config=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "#python" in result.outputs["output"]
|
||||
assert "#ai" in result.outputs["output"]
|
||||
assert "#workflow" in result.outputs["output"]
|
||||
|

@@ -0,0 +1,160 @@

import sys
import types
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch

import pytest

from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment
from core.workflow.entities import GraphInitParams
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable

if TYPE_CHECKING:  # pragma: no cover - imported for type checking only
    from core.workflow.nodes.tool.tool_node import ToolNode


@pytest.fixture
def tool_node(monkeypatch) -> "ToolNode":
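    # Stub core.ops.ops_trace_manager before importing ToolNode so the import
    # does not drag in the real tracing stack; bare `object` placeholders are
    # enough for these tests.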
module_name = "core.ops.ops_trace_manager"
|
||||
if module_name not in sys.modules:
|
||||
ops_stub = types.ModuleType(module_name)
|
||||
ops_stub.TraceQueueManager = object # pragma: no cover - stub attribute
|
||||
ops_stub.TraceTask = object # pragma: no cover - stub attribute
|
||||
monkeypatch.setitem(sys.modules, module_name, ops_stub)
|
||||
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
|
||||
graph_config: dict[str, Any] = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "tool-node",
|
||||
"data": {
|
||||
"type": "tool",
|
||||
"title": "Tool",
|
||||
"desc": "",
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"provider_name": "provider",
|
||||
"tool_name": "tool",
|
||||
"tool_label": "tool",
|
||||
"tool_configurations": {},
|
||||
"tool_parameters": {},
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
graph_config=graph_config,
|
||||
user_id="user-id",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(system_variables=SystemVariable(user_id="user-id"))
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
|
||||
config = graph_config["nodes"][0]
|
||||
node = ToolNode(
|
||||
id="node-instance",
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config["data"])
|
||||
return node
|
||||
|
||||
|
||||
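# A generator's `return` value is delivered via StopIteration.value (PEP 380),
# which is how the helper below recovers the LLMUsage returned by the node.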
def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
    events: list[Any] = []
    try:
        while True:
            events.append(next(generator))
    except StopIteration as stop:
        return events, stop.value

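# The message transformer is patched to an identity function so each test's
# message reaches the node exactly as constructed, without file-storage side
# effects.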
def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
    def _identity_transform(messages, *_args, **_kwargs):
        return messages

    tool_runtime = MagicMock()
    with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform):
        generator = tool_node._transform_message(
            messages=iter([message]),
            tool_info={"provider_type": "builtin", "provider_id": "provider"},
            parameters_for_log={},
            user_id="user-id",
            tenant_id="tenant-id",
            node_id=tool_node._node_id,
            tool_runtime=tool_runtime,
        )
        return _collect_events(generator)


def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
    file_obj = File(
        tenant_id="tenant-id",
        type=FileType.DOCUMENT,
        transfer_method=FileTransferMethod.TOOL_FILE,
        related_id="file-id",
        filename="demo.pdf",
        extension=".pdf",
        mime_type="application/pdf",
        size=123,
        storage_key="file-key",
    )
    message = ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.LINK,
        message=ToolInvokeMessage.TextMessage(text="/files/tools/file-id.pdf"),
        meta={"file": file_obj},
    )

    events, usage = _run_transform(tool_node, message)

    assert isinstance(usage, LLMUsage)

    chunk_events = [event for event in events if isinstance(event, StreamChunkEvent)]
    assert chunk_events
    assert chunk_events[0].chunk == "File: /files/tools/file-id.pdf\n"

    completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
    assert len(completed_events) == 1
    outputs = completed_events[0].node_run_result.outputs
    assert outputs["text"] == "File: /files/tools/file-id.pdf\n"

    files_segment = outputs["files"]
    assert isinstance(files_segment, ArrayFileSegment)
    assert files_segment.value == [file_obj]


def test_plain_link_messages_remain_links(tool_node: "ToolNode"):
    message = ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.LINK,
        message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
        meta=None,
    )

    events, _ = _run_transform(tool_node, message)

    chunk_events = [event for event in events if isinstance(event, StreamChunkEvent)]
    assert chunk_events
    assert chunk_events[0].chunk == "Link: https://dify.ai\n"

    completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
    assert len(completed_events) == 1
    files_segment = completed_events[0].node_run_result.outputs["files"]
    assert isinstance(files_segment, ArrayFileSegment)
    assert files_segment.value == []

@@ -0,0 +1,966 @@

"""
|
||||
Comprehensive unit tests for Tool models.
|
||||
|
||||
This test suite covers:
|
||||
- ToolProvider model validation (BuiltinToolProvider, ApiToolProvider)
|
||||
- BuiltinToolProvider relationships and credential management
|
||||
- ApiToolProvider credential storage and encryption
|
||||
- Tool OAuth client models
|
||||
- ToolLabelBinding relationships
|
||||
"""
|
||||
|
||||
import json
|
||||
from uuid import uuid4
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from models.tools import (
|
||||
ApiToolProvider,
|
||||
BuiltinToolProvider,
|
||||
ToolLabelBinding,
|
||||
ToolOAuthSystemClient,
|
||||
ToolOAuthTenantClient,
|
||||
)
|
||||
|
||||
|
||||
class TestBuiltinToolProviderValidation:
|
||||
"""Test suite for BuiltinToolProvider model validation and operations."""
|
||||
|
||||
def test_builtin_tool_provider_creation_with_required_fields(self):
|
||||
"""Test creating a builtin tool provider with all required fields."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
provider_name = "google"
|
||||
credentials = {"api_key": "test_key_123"}
|
||||
|
||||
# Act
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
name="Google API Key 1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert builtin_provider.tenant_id == tenant_id
|
||||
assert builtin_provider.user_id == user_id
|
||||
assert builtin_provider.provider == provider_name
|
||||
assert builtin_provider.name == "Google API Key 1"
|
||||
assert builtin_provider.encrypted_credentials == json.dumps(credentials)
|
||||
|
||||
def test_builtin_tool_provider_credentials_property(self):
|
||||
"""Test credentials property parses JSON correctly."""
|
||||
# Arrange
|
||||
credentials_data = {
|
||||
"api_key": "sk-test123",
|
||||
"auth_type": "api_key",
|
||||
"endpoint": "https://api.example.com",
|
||||
}
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="custom_provider",
|
||||
name="Custom Provider Key",
|
||||
encrypted_credentials=json.dumps(credentials_data),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = builtin_provider.credentials
|
||||
|
||||
# Assert
|
||||
assert result == credentials_data
|
||||
assert result["api_key"] == "sk-test123"
|
||||
assert result["auth_type"] == "api_key"
|
||||
|
||||
def test_builtin_tool_provider_credentials_empty_when_none(self):
|
||||
"""Test credentials property returns empty dict when encrypted_credentials is None."""
|
||||
# Arrange
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="test_provider",
|
||||
name="Test Provider",
|
||||
encrypted_credentials=None,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = builtin_provider.credentials
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_builtin_tool_provider_credentials_empty_when_empty_string(self):
|
||||
"""Test credentials property returns empty dict when encrypted_credentials is empty."""
|
||||
# Arrange
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="test_provider",
|
||||
name="Test Provider",
|
||||
encrypted_credentials="",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = builtin_provider.credentials
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_builtin_tool_provider_default_values(self):
|
||||
"""Test builtin tool provider default values."""
|
||||
# Arrange & Act
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="test_provider",
|
||||
name="Test Provider",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert builtin_provider.is_default is False
|
||||
assert builtin_provider.credential_type == "api-key"
|
||||
assert builtin_provider.expires_at == -1
|
||||
|
||||
def test_builtin_tool_provider_with_oauth_credential_type(self):
|
||||
"""Test builtin tool provider with OAuth credential type."""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"access_token": "oauth_token_123",
|
||||
"refresh_token": "refresh_token_456",
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
|
||||
# Act
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="google",
|
||||
name="Google OAuth",
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
credential_type="oauth2",
|
||||
expires_at=1735689600,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert builtin_provider.credential_type == "oauth2"
|
||||
assert builtin_provider.expires_at == 1735689600
|
||||
assert builtin_provider.credentials["access_token"] == "oauth_token_123"
|
||||
|
||||
def test_builtin_tool_provider_is_default_flag(self):
|
||||
"""Test is_default flag for builtin tool provider."""
|
||||
# Arrange
|
||||
provider1 = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="google",
|
||||
name="Google Key 1",
|
||||
is_default=True,
|
||||
)
|
||||
provider2 = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="google",
|
||||
name="Google Key 2",
|
||||
is_default=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert provider1.is_default is True
|
||||
assert provider2.is_default is False
|
||||
|
||||
def test_builtin_tool_provider_unique_constraint_fields(self):
|
||||
"""Test unique constraint fields (tenant_id, provider, name)."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
provider_name = "google"
|
||||
credential_name = "My Google Key"
|
||||
|
||||
# Act
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=str(uuid4()),
|
||||
provider=provider_name,
|
||||
name=credential_name,
|
||||
)
|
||||
|
||||
# Assert - these fields form unique constraint
|
||||
assert builtin_provider.tenant_id == tenant_id
|
||||
assert builtin_provider.provider == provider_name
|
||||
assert builtin_provider.name == credential_name
|
||||
|
||||
def test_builtin_tool_provider_multiple_credentials_same_provider(self):
|
||||
"""Test multiple credential sets for the same provider."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
provider = "openai"
|
||||
|
||||
# Act - create multiple credentials for same provider
|
||||
provider1 = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
name="OpenAI Key 1",
|
||||
encrypted_credentials=json.dumps({"api_key": "key1"}),
|
||||
)
|
||||
provider2 = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
name="OpenAI Key 2",
|
||||
encrypted_credentials=json.dumps({"api_key": "key2"}),
|
||||
)
|
||||
|
||||
# Assert - different names allow multiple credentials
|
||||
assert provider1.provider == provider2.provider
|
||||
assert provider1.name != provider2.name
|
||||
assert provider1.credentials != provider2.credentials
|
||||
|
||||
|
||||
class TestApiToolProviderValidation:
|
||||
"""Test suite for ApiToolProvider model validation and operations."""
|
||||
|
||||
def test_api_tool_provider_creation_with_required_fields(self):
|
||||
"""Test creating an API tool provider with all required fields."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
provider_name = "Custom API"
|
||||
schema = '{"openapi": "3.0.0", "info": {"title": "Test API"}}'
|
||||
tools = [{"name": "test_tool", "description": "A test tool"}]
|
||||
credentials = {"auth_type": "api_key", "api_key_value": "test123"}
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=provider_name,
|
||||
icon='{"type": "emoji", "value": "🔧"}',
|
||||
schema=schema,
|
||||
schema_type_str="openapi",
|
||||
description="Custom API for testing",
|
||||
tools_str=json.dumps(tools),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.tenant_id == tenant_id
|
||||
assert api_provider.user_id == user_id
|
||||
assert api_provider.name == provider_name
|
||||
assert api_provider.schema == schema
|
||||
assert api_provider.schema_type_str == "openapi"
|
||||
assert api_provider.description == "Custom API for testing"
|
||||
|
||||
def test_api_tool_provider_schema_type_property(self):
|
||||
"""Test schema_type property converts string to enum."""
|
||||
# Arrange
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Test API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = api_provider.schema_type
|
||||
|
||||
# Assert
|
||||
assert result == ApiProviderSchemaType.OPENAPI
|
||||
|
||||
def test_api_tool_provider_tools_property(self):
|
||||
"""Test tools property parses JSON and returns ApiToolBundle list."""
|
||||
# Arrange
|
||||
tools_data = [
|
||||
{
|
||||
"author": "test",
|
||||
"server_url": "https://api.weather.com",
|
||||
"method": "get",
|
||||
"summary": "Get weather information",
|
||||
"operation_id": "getWeather",
|
||||
"parameters": [],
|
||||
"openapi": {
|
||||
"operation_id": "getWeather",
|
||||
"parameters": [],
|
||||
"method": "get",
|
||||
"path": "/weather",
|
||||
"server_url": "https://api.weather.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
"author": "test",
|
||||
"server_url": "https://api.location.com",
|
||||
"method": "get",
|
||||
"summary": "Get location data",
|
||||
"operation_id": "getLocation",
|
||||
"parameters": [],
|
||||
"openapi": {
|
||||
"operation_id": "getLocation",
|
||||
"parameters": [],
|
||||
"method": "get",
|
||||
"path": "/location",
|
||||
"server_url": "https://api.location.com",
|
||||
},
|
||||
},
|
||||
]
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Weather API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Weather API",
|
||||
tools_str=json.dumps(tools_data),
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = api_provider.tools
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].operation_id == "getWeather"
|
||||
assert result[1].operation_id == "getLocation"
|
||||
|
||||
def test_api_tool_provider_credentials_property(self):
|
||||
"""Test credentials property parses JSON correctly."""
|
||||
# Arrange
|
||||
credentials_data = {
|
||||
"auth_type": "api_key_header",
|
||||
"api_key_header": "Authorization",
|
||||
"api_key_value": "Bearer test_token",
|
||||
"api_key_header_prefix": "bearer",
|
||||
}
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Secure API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Secure API",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials_data),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = api_provider.credentials
|
||||
|
||||
# Assert
|
||||
assert result["auth_type"] == "api_key_header"
|
||||
assert result["api_key_header"] == "Authorization"
|
||||
assert result["api_key_value"] == "Bearer test_token"
|
||||
|
||||
def test_api_tool_provider_with_privacy_policy(self):
|
||||
"""Test API tool provider with privacy policy."""
|
||||
# Arrange
|
||||
privacy_policy_url = "https://example.com/privacy"
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Privacy API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="API with privacy policy",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy_url,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.privacy_policy == privacy_policy_url
|
||||
|
||||
def test_api_tool_provider_with_custom_disclaimer(self):
|
||||
"""Test API tool provider with custom disclaimer."""
|
||||
# Arrange
|
||||
disclaimer = "This API is provided as-is without warranty."
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Disclaimer API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="API with disclaimer",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
custom_disclaimer=disclaimer,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.custom_disclaimer == disclaimer
|
||||
|
||||
def test_api_tool_provider_default_custom_disclaimer(self):
|
||||
"""Test API tool provider default custom_disclaimer is empty string."""
|
||||
# Arrange & Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Default API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="API",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.custom_disclaimer == ""
|
||||
|
||||
def test_api_tool_provider_unique_constraint_fields(self):
|
||||
"""Test unique constraint fields (name, tenant_id)."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
provider_name = "Unique API"
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=str(uuid4()),
|
||||
name=provider_name,
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Unique API",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Assert - these fields form unique constraint
|
||||
assert api_provider.tenant_id == tenant_id
|
||||
assert api_provider.name == provider_name
|
||||
|
||||
def test_api_tool_provider_with_no_auth(self):
|
||||
"""Test API tool provider with no authentication."""
|
||||
# Arrange
|
||||
credentials = {"auth_type": "none"}
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Public API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Public API with no auth",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.credentials["auth_type"] == "none"
|
||||
|
||||
def test_api_tool_provider_with_api_key_query_auth(self):
|
||||
"""Test API tool provider with API key in query parameter."""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"auth_type": "api_key_query",
|
||||
"api_key_query_param": "apikey",
|
||||
"api_key_value": "my_secret_key",
|
||||
}
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Query Auth API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="API with query auth",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.credentials["auth_type"] == "api_key_query"
|
||||
assert api_provider.credentials["api_key_query_param"] == "apikey"
|
||||
|
||||
|
||||
class TestToolOAuthModels:
|
||||
"""Test suite for OAuth client models (system and tenant level)."""
|
||||
|
||||
def test_oauth_system_client_creation(self):
|
||||
"""Test creating a system-level OAuth client."""
|
||||
# Arrange
|
||||
plugin_id = "builtin.google"
|
||||
provider = "google"
|
||||
oauth_params = json.dumps(
|
||||
{"client_id": "system_client_id", "client_secret": "system_secret", "scope": "email profile"}
|
||||
)
|
||||
|
||||
# Act
|
||||
oauth_client = ToolOAuthSystemClient(
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
encrypted_oauth_params=oauth_params,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert oauth_client.plugin_id == plugin_id
|
||||
assert oauth_client.provider == provider
|
||||
assert oauth_client.encrypted_oauth_params == oauth_params
|
||||
|
||||
def test_oauth_system_client_unique_constraint(self):
|
||||
"""Test unique constraint on plugin_id and provider."""
|
||||
# Arrange
|
||||
plugin_id = "builtin.github"
|
||||
provider = "github"
|
||||
|
||||
# Act
|
||||
oauth_client = ToolOAuthSystemClient(
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
encrypted_oauth_params="{}",
|
||||
)
|
||||
|
||||
# Assert - these fields form unique constraint
|
||||
assert oauth_client.plugin_id == plugin_id
|
||||
assert oauth_client.provider == provider
|
||||
|
||||
def test_oauth_tenant_client_creation(self):
|
||||
"""Test creating a tenant-level OAuth client."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
plugin_id = "builtin.google"
|
||||
provider = "google"
|
||||
|
||||
# Act
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
# Set encrypted_oauth_params after creation (it has init=False)
|
||||
oauth_params = json.dumps({"client_id": "tenant_client_id", "client_secret": "tenant_secret"})
|
||||
oauth_client.encrypted_oauth_params = oauth_params
|
||||
|
||||
# Assert
|
||||
assert oauth_client.tenant_id == tenant_id
|
||||
assert oauth_client.plugin_id == plugin_id
|
||||
assert oauth_client.provider == provider
|
||||
|
||||
def test_oauth_tenant_client_enabled_default(self):
|
||||
"""Test OAuth tenant client enabled flag has init=False and uses server default."""
|
||||
# Arrange & Act
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.slack",
|
||||
provider="slack",
|
||||
)
|
||||
|
||||
# Assert - enabled has init=False, so it won't be set until saved to DB
|
||||
# We can manually set it to test the field exists
|
||||
oauth_client.enabled = True
|
||||
assert oauth_client.enabled is True
|
||||
|
||||
def test_oauth_tenant_client_oauth_params_property(self):
|
||||
"""Test oauth_params property parses JSON correctly."""
|
||||
# Arrange
|
||||
params_data = {
|
||||
"client_id": "test_client_123",
|
||||
"client_secret": "secret_456",
|
||||
"redirect_uri": "https://app.example.com/callback",
|
||||
}
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.dropbox",
|
||||
provider="dropbox",
|
||||
)
|
||||
# Set encrypted_oauth_params after creation (it has init=False)
|
||||
oauth_client.encrypted_oauth_params = json.dumps(params_data)
|
||||
|
||||
# Act
|
||||
result = oauth_client.oauth_params
|
||||
|
||||
# Assert
|
||||
assert result == params_data
|
||||
assert result["client_id"] == "test_client_123"
|
||||
assert result["redirect_uri"] == "https://app.example.com/callback"
|
||||
|
||||
def test_oauth_tenant_client_oauth_params_empty_when_none(self):
|
||||
"""Test oauth_params returns empty dict when encrypted_oauth_params is None."""
|
||||
# Arrange
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.test",
|
||||
provider="test",
|
||||
)
|
||||
# encrypted_oauth_params has init=False, set it to None
|
||||
oauth_client.encrypted_oauth_params = None
|
||||
|
||||
# Act
|
||||
result = oauth_client.oauth_params
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_oauth_tenant_client_disabled_state(self):
|
||||
"""Test OAuth tenant client can be disabled."""
|
||||
# Arrange
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.microsoft",
|
||||
provider="microsoft",
|
||||
)
|
||||
|
||||
# Act
|
||||
oauth_client.enabled = False
|
||||
|
||||
# Assert
|
||||
assert oauth_client.enabled is False
|
||||
|
||||
|
||||
class TestToolLabelBinding:
|
||||
"""Test suite for ToolLabelBinding model."""
|
||||
|
||||
def test_tool_label_binding_creation(self):
|
||||
"""Test creating a tool label binding."""
|
||||
# Arrange
|
||||
tool_id = "google.search"
|
||||
tool_type = "builtin"
|
||||
label_name = "search"
|
||||
|
||||
# Act
|
||||
label_binding = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type=tool_type,
|
||||
label_name=label_name,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert label_binding.tool_id == tool_id
|
||||
assert label_binding.tool_type == tool_type
|
||||
assert label_binding.label_name == label_name
|
||||
|
||||
def test_tool_label_binding_unique_constraint(self):
|
||||
"""Test unique constraint on tool_id and label_name."""
|
||||
# Arrange
|
||||
tool_id = "openai.text_generation"
|
||||
label_name = "text"
|
||||
|
||||
# Act
|
||||
label_binding = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
label_name=label_name,
|
||||
)
|
||||
|
||||
# Assert - these fields form unique constraint
|
||||
assert label_binding.tool_id == tool_id
|
||||
assert label_binding.label_name == label_name
|
||||
|
||||
def test_tool_label_binding_multiple_labels_same_tool(self):
|
||||
"""Test multiple labels can be bound to the same tool."""
|
||||
# Arrange
|
||||
tool_id = "google.search"
|
||||
tool_type = "builtin"
|
||||
|
||||
# Act
|
||||
binding1 = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type=tool_type,
|
||||
label_name="search",
|
||||
)
|
||||
binding2 = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type=tool_type,
|
||||
label_name="productivity",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert binding1.tool_id == binding2.tool_id
|
||||
assert binding1.label_name != binding2.label_name
|
||||
|
||||
def test_tool_label_binding_different_tool_types(self):
|
||||
"""Test label bindings for different tool types."""
|
||||
# Arrange
|
||||
tool_types = ["builtin", "api", "workflow"]
|
||||
|
||||
# Act & Assert
|
||||
for tool_type in tool_types:
|
||||
binding = ToolLabelBinding(
|
||||
tool_id=f"test_tool_{tool_type}",
|
||||
tool_type=tool_type,
|
||||
label_name="test",
|
||||
)
|
||||
assert binding.tool_type == tool_type
|
||||
|
||||
|
||||
class TestCredentialStorage:
|
||||
"""Test suite for credential storage and encryption patterns."""
|
||||
|
||||
def test_builtin_provider_credential_storage_format(self):
|
||||
"""Test builtin provider stores credentials as JSON string."""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"api_key": "sk-test123",
|
||||
"endpoint": "https://api.example.com",
|
||||
"timeout": 30,
|
||||
}
|
||||
|
||||
# Act
|
||||
provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="test",
|
||||
name="Test Provider",
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(provider.encrypted_credentials, str)
|
||||
assert provider.credentials == credentials
|
||||
|
||||
def test_api_provider_credential_storage_format(self):
|
||||
"""Test API provider stores credentials as JSON string."""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"auth_type": "api_key_header",
|
||||
"api_key_header": "X-API-Key",
|
||||
"api_key_value": "secret_key_789",
|
||||
}
|
||||
|
||||
# Act
|
||||
provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Test API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(provider.credentials_str, str)
|
||||
assert provider.credentials == credentials
|
||||
|
||||
def test_builtin_provider_complex_credential_structure(self):
|
||||
"""Test builtin provider with complex nested credential structure."""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"auth_type": "oauth2",
|
||||
"oauth_config": {
|
||||
"access_token": "token123",
|
||||
"refresh_token": "refresh456",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
},
|
||||
"additional_headers": {"X-Custom-Header": "value"},
|
||||
}
|
||||
|
||||
# Act
|
||||
provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="oauth_provider",
|
||||
name="OAuth Provider",
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert provider.credentials["oauth_config"]["access_token"] == "token123"
|
||||
assert provider.credentials["additional_headers"]["X-Custom-Header"] == "value"
|
||||
|
||||
def test_api_provider_credential_update_pattern(self):
|
||||
"""Test pattern for updating API provider credentials."""
|
||||
# Arrange
|
||||
original_credentials = {"auth_type": "api_key_header", "api_key_value": "old_key"}
|
||||
provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Update Test",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(original_credentials),
|
||||
)
|
||||
|
||||
# Act - simulate credential update
|
||||
new_credentials = {"auth_type": "api_key_header", "api_key_value": "new_key"}
|
||||
provider.credentials_str = json.dumps(new_credentials)
|
||||
|
||||
# Assert
|
||||
assert provider.credentials["api_key_value"] == "new_key"
|
||||
|
||||
def test_builtin_provider_credential_expiration(self):
|
||||
"""Test builtin provider credential expiration tracking."""
|
||||
# Arrange
|
||||
future_timestamp = 1735689600 # Future date
|
||||
past_timestamp = 1609459200 # Past date
|
||||
|
||||
# Act
|
||||
active_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="active",
|
||||
name="Active Provider",
|
||||
expires_at=future_timestamp,
|
||||
)
|
||||
expired_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="expired",
|
||||
name="Expired Provider",
|
||||
expires_at=past_timestamp,
|
||||
)
|
||||
never_expires_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="permanent",
|
||||
name="Permanent Provider",
|
||||
expires_at=-1,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert active_provider.expires_at == future_timestamp
|
||||
assert expired_provider.expires_at == past_timestamp
|
||||
assert never_expires_provider.expires_at == -1
|
||||
|
||||
def test_oauth_client_credential_storage(self):
|
||||
"""Test OAuth client credential storage pattern."""
|
||||
# Arrange
|
||||
oauth_credentials = {
|
||||
"client_id": "oauth_client_123",
|
||||
"client_secret": "oauth_secret_456",
|
||||
"authorization_url": "https://oauth.example.com/authorize",
|
||||
"token_url": "https://oauth.example.com/token",
|
||||
"scope": "read write",
|
||||
}
|
||||
|
||||
# Act
|
||||
system_client = ToolOAuthSystemClient(
|
||||
plugin_id="builtin.oauth_test",
|
||||
provider="oauth_test",
|
||||
encrypted_oauth_params=json.dumps(oauth_credentials),
|
||||
)
|
||||
|
||||
tenant_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.oauth_test",
|
||||
provider="oauth_test",
|
||||
)
|
||||
# Set encrypted_oauth_params after creation (it has init=False)
|
||||
tenant_client.encrypted_oauth_params = json.dumps(oauth_credentials)
|
||||
|
||||
# Assert
|
||||
assert system_client.encrypted_oauth_params == json.dumps(oauth_credentials)
|
||||
assert tenant_client.oauth_params == oauth_credentials
|
||||
|
||||
|
||||
class TestToolProviderRelationships:
|
||||
"""Test suite for tool provider relationships and associations."""
|
||||
|
||||
def test_builtin_provider_tenant_relationship(self):
|
||||
"""Test builtin provider belongs to a tenant."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=str(uuid4()),
|
||||
provider="test",
|
||||
name="Test Provider",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert provider.tenant_id == tenant_id
|
||||
|
||||
def test_api_provider_user_relationship(self):
|
||||
"""Test API provider belongs to a user."""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=user_id,
|
||||
name="User API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert provider.user_id == user_id
|
||||
|
||||
def test_multiple_providers_same_tenant(self):
|
||||
"""Test multiple providers can belong to the same tenant."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
builtin1 = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider="google",
|
||||
name="Google Key 1",
|
||||
)
|
||||
builtin2 = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider="openai",
|
||||
name="OpenAI Key 1",
|
||||
)
|
||||
api1 = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name="Custom API 1",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert builtin1.tenant_id == tenant_id
|
||||
assert builtin2.tenant_id == tenant_id
|
||||
assert api1.tenant_id == tenant_id
|
||||
|
||||
def test_tool_label_bindings_for_provider_tools(self):
|
||||
"""Test tool label bindings can be associated with provider tools."""
|
||||
# Arrange
|
||||
provider_name = "google"
|
||||
tool_id = f"{provider_name}.search"
|
||||
|
||||
# Act
|
||||
binding1 = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
label_name="search",
|
||||
)
|
||||
binding2 = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
label_name="web",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert binding1.tool_id == tool_id
|
||||
assert binding2.tool_id == tool_id
|
||||
assert binding1.label_name != binding2.label_name
|
||||

@@ -6,10 +6,10 @@ from unittest.mock import Mock, patch

import pytest
from sqlalchemy.orm import Session, sessionmaker

from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
    DifyAPISQLAlchemyWorkflowRunRepository,
    _PrivateWorkflowPauseEntity,

@@ -129,12 +129,14 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):

            workflow_run_id=workflow_run_id,
            state_owner_user_id=state_owner_user_id,
            state=state,
            pause_reasons=[],
        )

        # Assert
        assert isinstance(result, _PrivateWorkflowPauseEntity)
        assert result.id == "pause-123"
        assert result.workflow_execution_id == workflow_run_id
        assert result.get_pause_reasons() == []

        # Verify database interactions
        mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)

@@ -156,6 +158,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):

workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
|
||||

@@ -174,6 +177,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):

workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
|
||||

@@ -316,19 +320,10 @@ class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):

class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test _PrivateWorkflowPauseEntity class."""

    def test_from_models(self, sample_workflow_pause: Mock):
        """Test creating _PrivateWorkflowPauseEntity from models."""
        # Act
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)

        # Assert
        assert isinstance(entity, _PrivateWorkflowPauseEntity)
        assert entity._pause_model == sample_workflow_pause

    def test_properties(self, sample_workflow_pause: Mock):
        """Test entity properties."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])

        # Act & Assert
        assert entity.id == sample_workflow_pause.id

@@ -338,7 +333,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)

    def test_get_state(self, sample_workflow_pause: Mock):
        """Test getting state from storage."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
        expected_state = b'{"test": "state"}'

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:

@@ -354,7 +349,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)

    def test_get_state_caching(self, sample_workflow_pause: Mock):
        """Test state caching in get_state method."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
        expected_state = b'{"test": "state"}'

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:

@@ -0,0 +1,932 @@

"""
|
||||
Comprehensive unit tests for DatasetCollectionBindingService.
|
||||
|
||||
This module contains extensive unit tests for the DatasetCollectionBindingService class,
|
||||
which handles dataset collection binding operations for vector database collections.
|
||||
|
||||
The DatasetCollectionBindingService provides methods for:
|
||||
- Retrieving or creating dataset collection bindings by provider, model, and type
|
||||
- Retrieving specific collection bindings by ID and type
|
||||
- Managing collection bindings for different collection types (dataset, etc.)
|
||||
|
||||
Collection bindings are used to map embedding models (provider + model name) to
|
||||
specific vector database collections, allowing datasets to share collections when
|
||||
they use the same embedding model configuration.
|
||||
|
||||
This test suite ensures:
|
||||
- Correct retrieval of existing bindings
|
||||
- Proper creation of new bindings when they don't exist
|
||||
- Accurate filtering by provider, model, and collection type
|
||||
- Proper error handling for missing bindings
|
||||
- Database transaction handling (add, commit)
|
||||
- Collection name generation using Dataset.gen_collection_name_by_id
|
||||
|
||||
================================================================================
|
||||
ARCHITECTURE OVERVIEW
|
||||
================================================================================
|
||||
|
||||
The DatasetCollectionBindingService is a critical component in the Dify platform's
|
||||
vector database management system. It serves as an abstraction layer between the
|
||||
application logic and the underlying vector database collections.
|
||||
|
||||
Key Concepts:
|
||||
1. Collection Binding: A mapping between an embedding model configuration
|
||||
(provider + model name) and a vector database collection name. This allows
|
||||
multiple datasets to share the same collection when they use identical
|
||||
embedding models, improving resource efficiency.
|
||||
|
||||
2. Collection Type: Different types of collections can exist (e.g., "dataset",
|
||||
"custom_type"). This allows for separation of collections based on their
|
||||
intended use case or data structure.
|
||||
|
||||
3. Provider and Model: The combination of provider_name (e.g., "openai",
|
||||
"cohere", "huggingface") and model_name (e.g., "text-embedding-ada-002")
|
||||
uniquely identifies an embedding model configuration.
|
||||
|
||||
4. Collection Name Generation: When a new binding is created, a unique collection
|
||||
name is generated using Dataset.gen_collection_name_by_id() with a UUID.
|
||||
This ensures each binding has a unique collection identifier.
|
||||
|
||||
================================================================================
|
||||
TESTING STRATEGY
|
||||
================================================================================
|
||||
|
||||
This test suite follows a comprehensive testing strategy that covers:
|
||||
|
||||
1. Happy Path Scenarios:
|
||||
- Successful retrieval of existing bindings
|
||||
- Successful creation of new bindings
|
||||
- Proper handling of default parameters
|
||||
|
||||
2. Edge Cases:
|
||||
- Different collection types
|
||||
- Various provider/model combinations
|
||||
- Default vs explicit parameter usage
|
||||
|
||||
3. Error Handling:
|
||||
- Missing bindings (for get_by_id_and_type)
|
||||
- Database query failures
|
||||
- Invalid parameter combinations
|
||||
|
||||
4. Database Interaction:
|
||||
- Query construction and execution
|
||||
- Transaction management (add, commit)
|
||||
- Query chaining (where, order_by, first)
|
||||
|
||||
5. Mocking Strategy:
|
||||
- Database session mocking
|
||||
- Query builder chain mocking
|
||||
- UUID generation mocking
|
||||
- Collection name generation mocking
|
||||
|
||||
================================================================================
|
||||
"""
|
||||
|
||||
"""
|
||||
Import statements for the test module.
|
||||
|
||||
This section imports all necessary dependencies for testing the
|
||||
DatasetCollectionBindingService, including:
|
||||
- unittest.mock for creating mock objects
|
||||
- pytest for test framework functionality
|
||||
- uuid for UUID generation (used in collection name generation)
|
||||
- Models and services from the application codebase
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
from services.dataset_service import DatasetCollectionBindingService
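
# A minimal sketch (illustrative only; the real logic lives in
# services.dataset_service) of the get-or-create flow described in the module
# docstring above. The helper name is hypothetical, and `session` is assumed
# to be a SQLAlchemy session like the one the tests below mock out:
import uuid


def _get_or_create_binding_sketch(session, provider_name: str, model_name: str, collection_type: str = "dataset"):
    binding = (
        session.query(DatasetCollectionBinding)
        .where(
            DatasetCollectionBinding.provider_name == provider_name,
            DatasetCollectionBinding.model_name == model_name,
            DatasetCollectionBinding.type == collection_type,
        )
        .order_by(DatasetCollectionBinding.created_at)
        .first()
    )
    if binding is None:
        # New bindings get a unique collection name derived from a fresh UUID.
        binding = DatasetCollectionBinding(
            provider_name=provider_name,
            model_name=model_name,
            collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
            type=collection_type,
        )
        session.add(binding)
        session.commit()
    return binding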

# ============================================================================
# Test Data Factory
# ============================================================================
# The Test Data Factory pattern is used here to centralize the creation of
# test objects and mock instances. This approach provides several benefits:
#
# 1. Consistency: All test objects are created using the same factory methods,
#    ensuring consistent structure across all tests.
#
# 2. Maintainability: If the structure of DatasetCollectionBinding or Dataset
#    changes, we only need to update the factory methods rather than every
#    individual test.
#
# 3. Reusability: Factory methods can be reused across multiple test classes,
#    reducing code duplication.
#
# 4. Readability: Tests become more readable when they use descriptive factory
#    method calls instead of complex object construction logic.
#
# ============================================================================


class DatasetCollectionBindingTestDataFactory:
    """
    Factory class for creating test data and mock objects for dataset collection binding tests.

    This factory provides static methods to create mock objects for:
    - DatasetCollectionBinding instances
    - Database query results
    - Collection name generation results

    The factory methods help maintain consistency across tests and reduce
    code duplication when setting up test scenarios.
    """

    @staticmethod
    def create_collection_binding_mock(
        binding_id: str = "binding-123",
        provider_name: str = "openai",
        model_name: str = "text-embedding-ada-002",
        collection_name: str = "collection-abc",
        collection_type: str = "dataset",
        created_at=None,
        **kwargs,
    ) -> Mock:
        """
        Create a mock DatasetCollectionBinding with the specified attributes.

        Args:
            binding_id: Unique identifier for the binding
            provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
            model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
            collection_name: Name of the vector database collection
            collection_type: Type of collection (default: "dataset")
            created_at: Optional datetime for the creation timestamp
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as a DatasetCollectionBinding instance
        """
        binding = Mock(spec=DatasetCollectionBinding)
        binding.id = binding_id
        binding.provider_name = provider_name
        binding.model_name = model_name
        binding.collection_name = collection_name
        binding.type = collection_type
        binding.created_at = created_at
        for key, value in kwargs.items():
            setattr(binding, key, value)
        return binding

    @staticmethod
    def create_dataset_mock(
        dataset_id: str = "dataset-123",
        **kwargs,
    ) -> Mock:
        """
        Create a mock Dataset for testing collection name generation.

        Args:
            dataset_id: Unique identifier for the dataset
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as a Dataset instance
        """
        dataset = Mock(spec=Dataset)
        dataset.id = dataset_id
        for key, value in kwargs.items():
            setattr(dataset, key, value)
        return dataset


# ============================================================================
# Tests for get_dataset_collection_binding
# ============================================================================


class TestDatasetCollectionBindingServiceGetBinding:
    """
    Comprehensive unit tests for the DatasetCollectionBindingService.get_dataset_collection_binding method.

    This test class covers the main collection binding retrieval/creation functionality,
    including various provider/model combinations, collection types, and edge cases.

    The get_dataset_collection_binding method:
    1. Queries for an existing binding by provider_name, model_name, and collection_type
    2. Orders results by created_at (ascending) and takes the first match
    3. If no binding exists, creates a new one with:
       - The provided provider_name and model_name
       - A generated collection_name using Dataset.gen_collection_name_by_id
       - The provided collection_type
    4. Adds the new binding to the database session and commits
    5. Returns the binding (either existing or newly created)

    Test scenarios include:
    - Retrieving existing bindings
    - Creating new bindings when none exist
    - Different collection types
    - Database transaction handling
    - Collection name generation
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Mock database session for testing database operations.

        Provides a mocked database session that can be used to verify:
        - Query construction and execution
        - Add operations for new bindings
        - Commit operations for transaction completion

        The mock is configured to return a query builder that supports
        chaining operations like .where(), .order_by(), and .first().
        """
        with patch("services.dataset_service.db.session") as mock_db:
            yield mock_db

    def test_get_dataset_collection_binding_existing_binding_success(self, mock_db_session):
        """
        Test successful retrieval of an existing collection binding.

        Verifies that when a binding already exists in the database for the given
        provider, model, and collection type, the method returns the existing binding
        without creating a new one.

        This test ensures:
        - The query is constructed correctly with all three filters
        - Results are ordered by created_at
        - The first matching binding is returned
        - No new binding is created (db.session.add is not called)
        - No commit is performed (db.session.commit is not called)
        """
        # Arrange
        provider_name = "openai"
        model_name = "text-embedding-ada-002"
        collection_type = "dataset"

        existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id="binding-123",
            provider_name=provider_name,
            model_name=model_name,
            collection_type=collection_type,
        )

        # Mock the query chain: query().where().order_by().first()
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = existing_binding
        mock_db_session.query.return_value = mock_query

        # Act
        result = DatasetCollectionBindingService.get_dataset_collection_binding(
            provider_name=provider_name, model_name=model_name, collection_type=collection_type
        )

        # Assert
        assert result == existing_binding
        assert result.id == "binding-123"
        assert result.provider_name == provider_name
        assert result.model_name == model_name
        assert result.type == collection_type

        # Verify the query was constructed correctly
        # The query should be constructed with DatasetCollectionBinding as the model
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)

        # Verify the where clause was applied to filter by provider, model, and type
        mock_query.where.assert_called_once()

        # Verify the results were ordered by created_at (ascending)
        # This ensures we get the oldest binding if multiple exist
        mock_where.order_by.assert_called_once()

        # Verify no new binding was created
        # Since an existing binding was found, we should not create a new one
        mock_db_session.add.assert_not_called()

        # Verify no commit was performed
        # Since no new binding was created, no database transaction is needed
        mock_db_session.commit.assert_not_called()

    def test_get_dataset_collection_binding_create_new_binding_success(self, mock_db_session):
        """
        Test successful creation of a new collection binding when none exists.

        Verifies that when no binding exists in the database for the given
        provider, model, and collection type, the method creates a new binding
        with a generated collection name and commits it to the database.

        This test ensures:
        - The query returns None (no existing binding)
        - A new DatasetCollectionBinding is created with correct attributes
        - Dataset.gen_collection_name_by_id is called to generate the collection name
        - The new binding is added to the database session
        - The transaction is committed
        - The newly created binding is returned
        """
        # Arrange
        provider_name = "cohere"
        model_name = "embed-english-v3.0"
        collection_type = "dataset"
        generated_collection_name = "collection-generated-xyz"

        # Mock the query chain to return None (no existing binding)
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = None  # No existing binding
        mock_db_session.query.return_value = mock_query

        # Mock Dataset.gen_collection_name_by_id to return a generated name
        with patch("services.dataset_service.Dataset.gen_collection_name_by_id") as mock_gen_name:
            mock_gen_name.return_value = generated_collection_name

            # Mock uuid.uuid4 for the collection name generation
            mock_uuid = "test-uuid-123"
            with patch("services.dataset_service.uuid.uuid4", return_value=mock_uuid):
                # Act
                result = DatasetCollectionBindingService.get_dataset_collection_binding(
                    provider_name=provider_name, model_name=model_name, collection_type=collection_type
                )

                # Assert
                assert result is not None
                assert result.provider_name == provider_name
                assert result.model_name == model_name
                assert result.type == collection_type
                assert result.collection_name == generated_collection_name

                # Verify Dataset.gen_collection_name_by_id was called with the generated UUID
                # This method generates a unique collection name based on the UUID
                # The UUID is converted to a string before passing to the method
                mock_gen_name.assert_called_once_with(str(mock_uuid))

                # Verify the new binding was added to the database session
                # The add method should be called exactly once with the new binding instance
                mock_db_session.add.assert_called_once()

                # Extract the binding that was added to verify its properties
                added_binding = mock_db_session.add.call_args[0][0]

                # Verify the added binding is an instance of DatasetCollectionBinding
                # This ensures we're creating the correct type of object
                assert isinstance(added_binding, DatasetCollectionBinding)

                # Verify all the binding properties are set correctly
                # These should match the input parameters to the method
                assert added_binding.provider_name == provider_name
                assert added_binding.model_name == model_name
                assert added_binding.type == collection_type

                # Verify the collection name was set from the generated name
                # This ensures the binding has a valid collection identifier
                assert added_binding.collection_name == generated_collection_name

                # Verify the transaction was committed
                # This ensures the new binding is persisted to the database
                mock_db_session.commit.assert_called_once()

    def test_get_dataset_collection_binding_different_collection_type(self, mock_db_session):
        """
        Test retrieval with a different collection type (not "dataset").

        Verifies that the method correctly filters by collection_type, allowing
        different types of collections to coexist with the same provider/model
        combination.

        This test ensures:
        - Collection type is properly used as a filter in the query
        - Different collection types can have separate bindings
        - The correct binding is returned based on type
        """
        # Arrange
        provider_name = "openai"
        model_name = "text-embedding-ada-002"
        collection_type = "custom_type"

        existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id="binding-456",
            provider_name=provider_name,
            model_name=model_name,
            collection_type=collection_type,
        )

        # Mock the query chain
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = existing_binding
        mock_db_session.query.return_value = mock_query

        # Act
        result = DatasetCollectionBindingService.get_dataset_collection_binding(
            provider_name=provider_name, model_name=model_name, collection_type=collection_type
        )

        # Assert
        assert result == existing_binding
        assert result.type == collection_type

        # Verify the query was constructed with the correct type filter
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        mock_query.where.assert_called_once()

    def test_get_dataset_collection_binding_default_collection_type(self, mock_db_session):
        """
        Test retrieval with the default collection type ("dataset").

        Verifies that when collection_type is not provided, it defaults to "dataset"
        as specified in the method signature.

        This test ensures:
        - The default value "dataset" is used when the type is not specified
        - The query correctly filters by the default type
        """
        # Arrange
        provider_name = "openai"
        model_name = "text-embedding-ada-002"
        # collection_type defaults to "dataset" in the method signature

        existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id="binding-789",
            provider_name=provider_name,
            model_name=model_name,
            collection_type="dataset",  # Default type
        )

        # Mock the query chain
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = existing_binding
        mock_db_session.query.return_value = mock_query

        # Act - call without specifying collection_type (uses default)
        result = DatasetCollectionBindingService.get_dataset_collection_binding(
            provider_name=provider_name, model_name=model_name
        )

        # Assert
        assert result == existing_binding
        assert result.type == "dataset"

        # Verify the query was constructed correctly
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)

    def test_get_dataset_collection_binding_different_provider_model_combination(self, mock_db_session):
        """
        Test retrieval with different provider/model combinations.

        Verifies that bindings are correctly filtered by both provider_name and
        model_name, ensuring that different model combinations have separate bindings.

        This test ensures:
        - Provider and model are both used as filters
        - Different combinations result in different bindings
        - The correct binding is returned for each combination
        """
        # Arrange
        provider_name = "huggingface"
        model_name = "sentence-transformers/all-MiniLM-L6-v2"
        collection_type = "dataset"

        existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id="binding-hf-123",
            provider_name=provider_name,
            model_name=model_name,
            collection_type=collection_type,
        )

        # Mock the query chain
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = existing_binding
        mock_db_session.query.return_value = mock_query

        # Act
        result = DatasetCollectionBindingService.get_dataset_collection_binding(
            provider_name=provider_name, model_name=model_name, collection_type=collection_type
        )

        # Assert
        assert result == existing_binding
        assert result.provider_name == provider_name
        assert result.model_name == model_name

        # Verify the query filters were applied correctly
        # The query should filter by both provider_name and model_name
        # This ensures different model combinations have separate bindings
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)

        # Verify the where clause was applied with all three filters:
        # - provider_name filter
        # - model_name filter
        # - collection_type filter
        mock_query.where.assert_called_once()

# ============================================================================
# Tests for get_dataset_collection_binding_by_id_and_type
# ============================================================================
# This section contains tests for the get_dataset_collection_binding_by_id_and_type
# method, which retrieves a specific collection binding by its ID and type.
#
# Key differences from get_dataset_collection_binding:
# 1. This method queries by ID and type, not by provider/model/type
# 2. This method does NOT create a new binding if one doesn't exist
# 3. This method raises ValueError if the binding is not found
# 4. This method is typically used when you already know the binding ID
#
# Use cases:
# - Retrieving a binding that was previously created
# - Validating that a binding exists before using it
# - Accessing binding metadata when you have the ID
#
# ============================================================================


class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
    """
    Comprehensive unit tests for the DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type method.

    This test class covers collection binding retrieval by ID and type,
    including success scenarios and error handling for missing bindings.

    The get_dataset_collection_binding_by_id_and_type method:
    1. Queries for a binding by collection_binding_id and collection_type
    2. Orders results by created_at (ascending) and takes the first match
    3. If no binding exists, raises ValueError("Dataset collection binding not found")
    4. Returns the found binding

    Unlike get_dataset_collection_binding, this method does NOT create a new
    binding if one doesn't exist - it only retrieves existing bindings.

    Test scenarios include:
    - Successful retrieval of existing bindings
    - Error handling for missing bindings
    - Different collection types
    - Default collection type behavior
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Mock database session for testing database operations.

        Provides a mocked database session that can be used to verify:
        - Query construction with ID and type filters
        - Ordering by created_at
        - First result retrieval

        The mock is configured to return a query builder that supports
        chaining operations like .where(), .order_by(), and .first().
        """
        with patch("services.dataset_service.db.session") as mock_db:
            yield mock_db

    def test_get_dataset_collection_binding_by_id_and_type_success(self, mock_db_session):
        """
        Test successful retrieval of a collection binding by ID and type.

        Verifies that when a binding exists in the database with the given
        ID and collection type, the method returns the binding.

        This test ensures:
        - The query is constructed correctly with ID and type filters
        - Results are ordered by created_at
        - The first matching binding is returned
        - No error is raised
        """
        # Arrange
        collection_binding_id = "binding-123"
        collection_type = "dataset"

        existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id=collection_binding_id,
            provider_name="openai",
            model_name="text-embedding-ada-002",
            collection_type=collection_type,
        )

        # Mock the query chain: query().where().order_by().first()
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = existing_binding
        mock_db_session.query.return_value = mock_query

        # Act
        result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id=collection_binding_id, collection_type=collection_type
        )

        # Assert
        assert result == existing_binding
        assert result.id == collection_binding_id
        assert result.type == collection_type

        # Verify the query was constructed correctly
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        mock_query.where.assert_called_once()
        mock_where.order_by.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, mock_db_session):
        """
        Test error handling when the binding is not found.

        Verifies that when no binding exists in the database with the given
        ID and collection type, the method raises a ValueError with the
        message "Dataset collection binding not found".

        This test ensures:
        - The query returns None (no existing binding)
        - ValueError is raised with the correct message
        - No binding is returned
        """
        # Arrange
        collection_binding_id = "non-existent-binding"
        collection_type = "dataset"

        # Mock the query chain to return None (no existing binding)
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = None  # No existing binding
        mock_db_session.query.return_value = mock_query

        # Act & Assert
        with pytest.raises(ValueError, match="Dataset collection binding not found"):
            DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                collection_binding_id=collection_binding_id, collection_type=collection_type
            )

        # Verify the query was attempted
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        mock_query.where.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, mock_db_session):
        """
        Test retrieval with a different collection type.

        Verifies that the method correctly filters by collection_type, ensuring
        that bindings with the same ID but different types are treated as
        separate entities.

        This test ensures:
        - Collection type is properly used as a filter in the query
        - Different collection types can have separate bindings with the same ID
        - The correct binding is returned based on type
        """
        # Arrange
        collection_binding_id = "binding-456"
        collection_type = "custom_type"

        existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id=collection_binding_id,
            provider_name="cohere",
            model_name="embed-english-v3.0",
            collection_type=collection_type,
        )

        # Mock the query chain
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = existing_binding
        mock_db_session.query.return_value = mock_query

        # Act
        result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id=collection_binding_id, collection_type=collection_type
        )

        # Assert
        assert result == existing_binding
        assert result.id == collection_binding_id
        assert result.type == collection_type

        # Verify the query was constructed with the correct type filter
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        mock_query.where.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, mock_db_session):
        """
        Test retrieval with the default collection type ("dataset").

        Verifies that when collection_type is not provided, it defaults to "dataset"
        as specified in the method signature.

        This test ensures:
        - The default value "dataset" is used when the type is not specified
        - The query correctly filters by the default type
        - The correct binding is returned
        """
        # Arrange
        collection_binding_id = "binding-789"
        # collection_type defaults to "dataset" in the method signature

        existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding_mock(
            binding_id=collection_binding_id,
            provider_name="openai",
            model_name="text-embedding-ada-002",
            collection_type="dataset",  # Default type
        )

        # Mock the query chain
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = existing_binding
        mock_db_session.query.return_value = mock_query

        # Act - call without specifying collection_type (uses default)
        result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
            collection_binding_id=collection_binding_id
        )

        # Assert
        assert result == existing_binding
        assert result.id == collection_binding_id
        assert result.type == "dataset"

        # Verify the query was constructed correctly
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)
        mock_query.where.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, mock_db_session):
        """
        Test error handling when the binding exists but with the wrong collection type.

        Verifies that when a binding exists with the given ID but a different
        collection type, the method raises a ValueError because the binding
        doesn't match both the ID and type criteria.

        This test ensures:
        - The query correctly filters by both ID and type
        - Bindings with a matching ID but different type are not returned
        - ValueError is raised when no matching binding is found
        """
        # Arrange
        collection_binding_id = "binding-123"
        collection_type = "dataset"

        # Mock the query chain to return None (binding exists but with a different type)
        mock_query = Mock()
        mock_where = Mock()
        mock_order_by = Mock()
        mock_query.where.return_value = mock_where
        mock_where.order_by.return_value = mock_order_by
        mock_order_by.first.return_value = None  # No matching binding
        mock_db_session.query.return_value = mock_query

        # Act & Assert
        with pytest.raises(ValueError, match="Dataset collection binding not found"):
            DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
                collection_binding_id=collection_binding_id, collection_type=collection_type
            )

        # Verify the query was attempted with both ID and type filters
        # The query should filter by both collection_binding_id and collection_type
        # This ensures we only get bindings that match both criteria
        mock_db_session.query.assert_called_once_with(DatasetCollectionBinding)

        # Verify the where clause was applied with both filters:
        # - collection_binding_id filter (exact match)
        # - collection_type filter (exact match)
        mock_query.where.assert_called_once()

        # Note: The order_by and first() calls are also part of the query chain,
        # but we don't need to verify them separately since they're part of the
        # standard query pattern used by both methods in this service.


# ============================================================================
# Additional Test Scenarios and Edge Cases
# ============================================================================
# The following section could contain additional test scenarios if needed:
#
# Potential additional tests:
# 1. Test with multiple existing bindings (verify ordering by created_at)
# 2. Test with very long provider/model names (boundary testing)
# 3. Test with special characters in provider/model names
# 4. Test concurrent binding creation (thread safety)
# 5. Test database rollback scenarios
# 6. Test with None values for optional parameters
# 7. Test with empty strings for required parameters
# 8. Test collection name generation uniqueness
# 9. Test with different UUID formats
# 10. Test query performance with large datasets
#
# These scenarios are not currently implemented but could be added if needed
# based on real-world usage patterns or discovered edge cases.
#
# ============================================================================


# ============================================================================
# Integration Notes and Best Practices
# ============================================================================
#
# When using DatasetCollectionBindingService in production code, consider:
#
# 1. Error Handling:
#    - Always handle ValueError exceptions when calling
#      get_dataset_collection_binding_by_id_and_type
#    - Check return values from get_dataset_collection_binding to ensure
#      bindings were created successfully
#
# 2. Performance Considerations:
#    - The service queries the database on every call, so consider caching
#      bindings if they're accessed frequently (see the sketch below)
#    - Collection bindings are typically long-lived, so caching is safe
#
# 3. Transaction Management:
#    - New bindings are automatically committed to the database
#    - If you need to rollback, ensure you're within a transaction context
#
# 4. Collection Type Usage:
#    - Use "dataset" for standard dataset collections
#    - Use custom types only when you need to separate collections by purpose
#    - Be consistent with collection type naming across your application
#
# 5. Provider and Model Naming:
#    - Use consistent provider names (e.g., "openai", not "OpenAI" or "OPENAI")
#    - Use exact model names as provided by the model provider
#    - These names are case-sensitive and must match exactly
#
# ============================================================================
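
# A minimal sketch of the caching idea from note 2 above. The helper name and
# the cache policy are hypothetical, not part of the service:
from functools import lru_cache


@lru_cache(maxsize=128)
def _cached_binding_sketch(provider_name: str, model_name: str, collection_type: str = "dataset"):
    """Memoize long-lived bindings so repeated lookups skip the database."""
    return DatasetCollectionBindingService.get_dataset_collection_binding(
        provider_name=provider_name, model_name=model_name, collection_type=collection_type
    )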

# ============================================================================
# Database Schema Reference
# ============================================================================
#
# The DatasetCollectionBinding model has the following structure
# (an illustrative SQLAlchemy sketch follows below):
#
# - id: StringUUID (primary key, auto-generated)
# - provider_name: String(255) (required, e.g., "openai", "cohere")
# - model_name: String(255) (required, e.g., "text-embedding-ada-002")
# - type: String(40) (required, default: "dataset")
# - collection_name: String(64) (required, unique collection identifier)
# - created_at: DateTime (auto-generated timestamp)
#
# Indexes:
# - Primary key on id
# - Composite index on (provider_name, model_name) for efficient lookups
#
# Relationships:
# - One binding can be referenced by multiple datasets
# - Datasets reference bindings via collection_binding_id
#
# ============================================================================
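
# An illustrative-only mirror of the shape described above, using a throwaway
# declarative base so it cannot clash with the real model in models.dataset
# (column types are simplified and the real model may differ):
from datetime import datetime

from sqlalchemy import DateTime, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class _SketchBase(DeclarativeBase):
    pass


class _DatasetCollectionBindingSketch(_SketchBase):
    __tablename__ = "dataset_collection_bindings_sketch"

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_name: Mapped[str] = mapped_column(String(255), nullable=False)
    type: Mapped[str] = mapped_column(String(40), nullable=False, default="dataset")
    collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime)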

# ============================================================================
# Mocking Strategy Documentation
# ============================================================================
#
# This test suite uses extensive mocking to isolate the unit under test.
# Here's how the mocking strategy works:
#
# 1. Database Session Mocking:
#    - db.session is patched to prevent actual database access
#    - Query chains are mocked to return predictable results
#    - Add and commit operations are tracked for verification
#
# 2. Query Chain Mocking:
#    - query() returns a mock query object
#    - where() returns a mock where object
#    - order_by() returns a mock order_by object
#    - first() returns the final result (binding or None)
#    (this pattern is captured once as a helper sketch below)
#
# 3. UUID Generation Mocking:
#    - uuid.uuid4() is mocked to return predictable UUIDs
#    - This ensures collection names are generated consistently in tests
#
# 4. Collection Name Generation Mocking:
#    - Dataset.gen_collection_name_by_id() is mocked
#    - This allows us to verify the method is called correctly
#    - We can control the generated collection name for testing
#
# Benefits of this approach:
# - Tests run quickly (no database I/O)
# - Tests are deterministic (no random UUIDs)
# - Tests are isolated (no side effects)
# - Tests are maintainable (clear mock setup)
#
# ============================================================================
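
# The query-chain pattern documented above, captured once as a reusable helper
# (an illustrative sketch; the tests in this module inline the same setup):
def _build_query_chain_mock(first_result):
    """Return a query mock whose .where().order_by().first() yields first_result."""
    mock_query = Mock()
    mock_query.where.return_value.order_by.return_value.first.return_value = first_result
    return mock_query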

@@ -0,0 +1,920 @@
"""
Extensive unit tests for ``ExternalDatasetService``.

This module focuses on the *external dataset service* surface area, which is responsible
for integrating with **external knowledge APIs** and wiring them into Dify datasets.

The goal of this test suite is twofold:

- Provide **high-confidence regression coverage** for all public helpers on
  ``ExternalDatasetService``.
- Serve as **executable documentation** for how external API integration is expected
  to behave in different scenarios (happy paths, validation failures, and error codes).

The file intentionally contains **rich comments and generous spacing** in order to make
each scenario easy to scan during reviews.
"""

from __future__ import annotations

import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch

import httpx
import pytest

from constants import HIDDEN_VALUE
from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
from services.entities.external_knowledge_entities.external_knowledge_entities import (
    Authorization,
    AuthorizationConfig,
    ExternalKnowledgeApiSetting,
)
from services.errors.dataset import DatasetNameDuplicateError
from services.external_knowledge_service import ExternalDatasetService


class ExternalDatasetTestDataFactory:
    """
    Factory helpers for building *lightweight* mocks for external knowledge tests.

    These helpers are intentionally small and explicit:

    - They avoid pulling in unnecessary fixtures.
    - They reflect the minimal contract that the service under test cares about.
    """

    @staticmethod
    def create_external_api(
        api_id: str = "api-123",
        tenant_id: str = "tenant-1",
        name: str = "Test API",
        description: str = "Description",
        settings: dict | None = None,
    ) -> ExternalKnowledgeApis:
        """
        Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields.

        Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to
        exercise ``settings_dict`` and other convenience properties if needed.
        """
        instance = ExternalKnowledgeApis(
            tenant_id=tenant_id,
            name=name,
            description=description,
            # The model stores settings as a JSON string, so serialize the dict here.
            settings=None if settings is None else json.dumps(settings),
        )

        # Overwrite the generated id for determinism in assertions.
        instance.id = api_id
        return instance

    @staticmethod
    def create_dataset(
        dataset_id: str = "ds-1",
        tenant_id: str = "tenant-1",
        name: str = "External Dataset",
        provider: str = "external",
    ) -> Dataset:
        """
        Build a small ``Dataset`` instance representing an external dataset.
        """
        dataset = Dataset(
            tenant_id=tenant_id,
            name=name,
            description="",
            provider=provider,
            created_by="user-1",
        )
        dataset.id = dataset_id
        return dataset

    @staticmethod
    def create_external_binding(
        tenant_id: str = "tenant-1",
        dataset_id: str = "ds-1",
        api_id: str = "api-1",
        external_knowledge_id: str = "knowledge-1",
    ) -> ExternalKnowledgeBindings:
        """
        Small helper for a binding between a dataset and an external knowledge API.
        """
        binding = ExternalKnowledgeBindings(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            external_knowledge_api_id=api_id,
            external_knowledge_id=external_knowledge_id,
            created_by="user-1",
        )
        return binding


# ---------------------------------------------------------------------------
# get_external_knowledge_apis
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceGetExternalKnowledgeApis:
    """
    Tests for ``ExternalDatasetService.get_external_knowledge_apis``.

    These tests focus on:

    - Basic pagination wiring via ``db.paginate``.
    - Optional search keyword behaviour.
    """

    @pytest.fixture
    def mock_db_paginate(self):
        """
        Patch ``db.paginate`` so we do not touch the real database layer.
        """
        with (
            patch("services.external_knowledge_service.db.paginate") as mock_paginate,
            patch("services.external_knowledge_service.select"),
        ):
            yield mock_paginate

    def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock):
        """
        It should return the ``items`` and ``total`` coming from the paginate object.
        """
        # Arrange
        tenant_id = "tenant-1"
        page = 1
        per_page = 20

        mock_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)]
        mock_pagination = SimpleNamespace(items=mock_items, total=42)
        mock_db_paginate.return_value = mock_pagination

        # Act
        items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id)

        # Assert
        assert items is mock_items
        assert total == 42

        mock_db_paginate.assert_called_once()
        call_kwargs = mock_db_paginate.call_args.kwargs
        assert call_kwargs["page"] == page
        assert call_kwargs["per_page"] == per_page
        assert call_kwargs["max_per_page"] == 100
        assert call_kwargs["error_out"] is False

    def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock):
        """
        When a search keyword is provided, the query should be adjusted
        (we simply assert that paginate is still called and does not explode).
        """
        # Arrange
        tenant_id = "tenant-1"
        page = 2
        per_page = 10
        search = "foo"

        mock_pagination = SimpleNamespace(items=[], total=0)
        mock_db_paginate.return_value = mock_pagination

        # Act
        items, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id, search=search)

        # Assert
        assert items == []
        assert total == 0
        mock_db_paginate.assert_called_once()


# ---------------------------------------------------------------------------
# validate_api_list
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceValidateApiList:
    """
    Lightweight validation tests for ``validate_api_list``.
    """

    def test_validate_api_list_success(self):
        """
        A minimal valid configuration (endpoint + api_key) should pass.
        """
        config = {"endpoint": "https://example.com", "api_key": "secret"}

        # Act & Assert – no exception expected
        ExternalDatasetService.validate_api_list(config)

    @pytest.mark.parametrize(
        ("config", "expected_message"),
        [
            ({}, "api list is empty"),
            ({"api_key": "k"}, "endpoint is required"),
            ({"endpoint": "https://example.com"}, "api_key is required"),
        ],
    )
    def test_validate_api_list_failures(self, config: dict, expected_message: str):
        """
        Invalid configs should raise ``ValueError`` with a clear message.
        """
        with pytest.raises(ValueError, match=expected_message):
            ExternalDatasetService.validate_api_list(config)


# ---------------------------------------------------------------------------
# create_external_knowledge_api & get/update/delete
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceCrudExternalKnowledgeApi:
    """
    CRUD tests for external knowledge API templates.
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Patch ``db.session`` for all CRUD tests in this class.
        """
        with patch("services.external_knowledge_service.db.session") as mock_session:
            yield mock_session

    def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
        """
        ``create_external_knowledge_api`` should persist a new record
        when settings are present and valid.
        """
        tenant_id = "tenant-1"
        user_id = "user-1"
        args = {
            "name": "API",
            "description": "desc",
            "settings": {"endpoint": "https://api.example.com", "api_key": "secret"},
        }

        # We do not want to actually call the remote endpoint here, so we patch the validator.
        with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check:
            result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)

        assert isinstance(result, ExternalKnowledgeApis)
        mock_check.assert_called_once_with(args["settings"])
        mock_db_session.add.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock):
        """
        Missing ``settings`` should result in a ``ValueError``.
        """
        tenant_id = "tenant-1"
        user_id = "user-1"
        args = {"name": "API", "description": "desc"}

        with pytest.raises(ValueError, match="settings is required"):
            ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)

        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()

    def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock):
        """
        ``get_external_knowledge_api`` should return the first matching record.
        """
        api = Mock(spec=ExternalKnowledgeApis)
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = api

        result = ExternalDatasetService.get_external_knowledge_api("api-id")
        assert result is api

    def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        When the record is absent, a ``ValueError`` is raised.
        """
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.get_external_knowledge_api("missing-id")

    def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock):
        """
        Updating an API should keep the existing API key when the special hidden
        value placeholder is sent from the UI.
        """
        tenant_id = "tenant-1"
        user_id = "user-1"
        api_id = "api-1"

        existing_api = Mock(spec=ExternalKnowledgeApis)
        existing_api.settings_dict = {"api_key": "stored-key"}
        existing_api.settings = '{"api_key":"stored-key"}'
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api

        args = {
            "name": "New Name",
            "description": "New Desc",
            "settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
        }

        result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)

        assert result is existing_api
        # The placeholder should be replaced with the stored key.
        assert args["settings"]["api_key"] == "stored-key"
        mock_db_session.commit.assert_called_once()
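
    # A standalone sketch of the placeholder rule the test above exercises
    # (illustrative only; the real logic lives in update_external_knowledge_api):
    @staticmethod
    def _resolve_api_key_sketch(incoming: str, stored: str) -> str:
        """Keep the stored key when the UI sends the HIDDEN_VALUE placeholder."""
        return stored if incoming == HIDDEN_VALUE else incoming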

    def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        Updating a non-existent API template should raise ``ValueError``.
        """
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.update_external_knowledge_api(
                tenant_id="tenant-1",
                user_id="user-1",
                external_knowledge_api_id="missing-id",
                args={"name": "n", "description": "d", "settings": {}},
            )

    def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock):
        """
        ``delete_external_knowledge_api`` should delete and commit when the record is found.
        """
        api = Mock(spec=ExternalKnowledgeApis)
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = api

        ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")

        mock_db_session.delete.assert_called_once_with(api)
        mock_db_session.commit.assert_called_once()

    def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        Deletion of a missing template should raise ``ValueError``.
        """
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")


# ---------------------------------------------------------------------------
# external_knowledge_api_use_check & binding lookups
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceUsageAndBindings:
    """
    Tests for usage checks and dataset binding retrieval.
    """

    @pytest.fixture
    def mock_db_session(self):
        with patch("services.external_knowledge_service.db.session") as mock_session:
            yield mock_session

    def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
        """
        When there are bindings, ``external_knowledge_api_use_check`` returns True and the count.
        """
        mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3

        in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")

        assert in_use is True
        assert count == 3

    def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
        """
        Zero bindings should return ``(False, 0)``.
        """
        mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0

        in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")

        assert in_use is False
        assert count == 0

    def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock):
        """
        Binding lookup should return the first record when present.
        """
        binding = Mock(spec=ExternalKnowledgeBindings)
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding

        result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
        assert result is binding

    def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock):
        """
        A missing binding should result in a ``ValueError``.
        """
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = None

        with pytest.raises(ValueError, match="external knowledge binding not found"):
            ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")


# ---------------------------------------------------------------------------
# document_create_args_validate
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceDocumentCreateArgsValidate:
    """
    Tests for ``document_create_args_validate``.
    """

    @pytest.fixture
    def mock_db_session(self):
        with patch("services.external_knowledge_service.db.session") as mock_session:
            yield mock_session

    def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
        """
        All required custom parameters present – validation should pass.
        """
        external_api = Mock(spec=ExternalKnowledgeApis)
        # Raw JSON string; the service itself calls json.loads on it
        external_api.settings = (
            '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
        )
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api

        process_parameter = {"foo": "value", "bar": "optional"}

        # Act & Assert – no exception
        ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)

    def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock):
        """
        When the referenced API template is missing, a ``ValueError`` is raised.
        """
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = None

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})

    def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock):
        """
        Required document process parameters must be supplied.
        """
        external_api = Mock(spec=ExternalKnowledgeApis)
        external_api.settings = (
            '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
        )
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api

        process_parameter = {"bar": "present"}  # missing "foo"

        with pytest.raises(ValueError, match="foo is required"):
            ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", process_parameter)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_external_api
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExternalDatasetServiceProcessExternalApi:
|
||||
"""
|
||||
Tests focused on the HTTP request assembly and method mapping behaviour.
|
||||
"""
|
||||
|
||||
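    # The POST test below pins down how ``process_external_api`` forwards the
    # settings: URL and headers pass through verbatim, redirects are followed,
    # and ``params`` travels in the request body (hence the ``"data"`` check).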
    def test_process_external_api_valid_method_post(self):
        """
        For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function.
        """

        settings = ExternalKnowledgeApiSetting(
            url="https://example.com/path",
            request_method="POST",
            headers={"X-Test": "1"},
            params={"foo": "bar"},
        )

        fake_response = httpx.Response(200)

        with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post:
            mock_post.return_value = fake_response

            result = ExternalDatasetService.process_external_api(settings, files=None)

            assert result is fake_response
            mock_post.assert_called_once()
            kwargs = mock_post.call_args.kwargs
            assert kwargs["url"] == settings.url
            assert kwargs["headers"] == settings.headers
            assert kwargs["follow_redirects"] is True
            assert "data" in kwargs

    def test_process_external_api_invalid_method_raises(self):
        """
        An unsupported HTTP verb should raise ``InvalidHttpMethodError``.
        """

        settings = ExternalKnowledgeApiSetting(
            url="https://example.com",
            request_method="INVALID",
            headers=None,
            params={},
        )

        from core.workflow.nodes.http_request.exc import InvalidHttpMethodError

        with pytest.raises(InvalidHttpMethodError):
            ExternalDatasetService.process_external_api(settings, files=None)


# ---------------------------------------------------------------------------
# assembling_headers
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceAssemblingHeaders:
    """
    Tests for header assembly based on different authentication flavours.
    """

    def test_assembling_headers_bearer_token(self):
        """
        For bearer auth we expect ``Authorization: Bearer <key>`` by default.
        """

        auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="bearer", api_key="secret", header=None),
        )

        headers = ExternalDatasetService.assembling_headers(auth)

        assert headers["Authorization"] == "Bearer secret"

    def test_assembling_headers_basic_token_with_custom_header(self):
        """
        For basic auth we honour the configured header name.
        """

        auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"),
        )

        headers = ExternalDatasetService.assembling_headers(auth, headers={"Existing": "1"})

        assert headers["Existing"] == "1"
        assert headers["X-Auth"] == "Basic abc123"

    def test_assembling_headers_custom_type(self):
        """
        Custom auth type should inject the raw API key.
        """

        auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"),
        )

        headers = ExternalDatasetService.assembling_headers(auth, headers=None)

        assert headers["X-API-KEY"] == "raw-key"

    def test_assembling_headers_missing_config_raises(self):
        """
        Missing config object should be rejected.
        """

        auth = Authorization(type="api-key", config=None)

        with pytest.raises(ValueError, match="authorization config is required"):
            ExternalDatasetService.assembling_headers(auth)

    def test_assembling_headers_missing_api_key_raises(self):
        """
        ``api_key`` is required when type is ``api-key``.
        """

        auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"),
        )

        with pytest.raises(ValueError, match="api_key is required"):
            ExternalDatasetService.assembling_headers(auth)

    def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self):
        """
        For ``no-auth`` we should not modify the headers mapping.
        """

        auth = Authorization(type="no-auth", config=None)

        base_headers = {"X": "1"}
        result = ExternalDatasetService.assembling_headers(auth, headers=base_headers)

        # A copy is returned, original is not mutated.
        assert result == base_headers
        assert result is not base_headers


# ---------------------------------------------------------------------------
# get_external_knowledge_api_settings
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceGetExternalKnowledgeApiSettings:
    """
    Simple shape test for ``get_external_knowledge_api_settings``.
    """

    def test_get_external_knowledge_api_settings(self):
        settings_dict: dict[str, Any] = {
            "url": "https://example.com/retrieval",
            "request_method": "post",
            "headers": {"Content-Type": "application/json"},
            "params": {"foo": "bar"},
        }

        result = ExternalDatasetService.get_external_knowledge_api_settings(settings_dict)

        assert isinstance(result, ExternalKnowledgeApiSetting)
        assert result.url == settings_dict["url"]
        assert result.request_method == settings_dict["request_method"]
        assert result.headers == settings_dict["headers"]
        assert result.params == settings_dict["params"]


# ---------------------------------------------------------------------------
# create_external_dataset
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceCreateExternalDataset:
    """
    Tests around creating the external dataset and its binding row.
    """

    @pytest.fixture
    def mock_db_session(self):
        with patch("services.external_knowledge_service.db.session") as mock_session:
            yield mock_session

    def test_create_external_dataset_success(self, mock_db_session: MagicMock):
        """
        A brand new dataset name with valid external knowledge references
        should create both the dataset and its binding.
        """

        tenant_id = "tenant-1"
        user_id = "user-1"

        args = {
            "name": "My Dataset",
            "description": "desc",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
            "external_retrieval_model": {"top_k": 3},
        }

        # No existing dataset with same name.
        mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
            None,  # duplicate-name check
            Mock(spec=ExternalKnowledgeApis),  # external knowledge api
        ]

        dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)

        assert isinstance(dataset, Dataset)
        assert dataset.provider == "external"
        assert dataset.retrieval_model == args["external_retrieval_model"]

        assert mock_db_session.add.call_count >= 2  # dataset + binding
        mock_db_session.flush.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock):
        """
        When a dataset with the same name already exists,
        ``DatasetNameDuplicateError`` is raised.
        """

        existing_dataset = Mock(spec=Dataset)
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset

        args = {
            "name": "Existing",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
        }

        with pytest.raises(DatasetNameDuplicateError):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)

        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()

    def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock):
        """
        If the referenced external knowledge API does not exist, a ``ValueError`` is raised.
        """

        # First call: duplicate name check – not found.
        mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
            None,
            None,  # external knowledge api lookup
        ]

        args = {
            "name": "Dataset",
            "external_knowledge_api_id": "missing",
            "external_knowledge_id": "knowledge-1",
        }

        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)

    def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock):
        """
        ``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
        """

        # duplicate name check
        mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
            None,
            Mock(spec=ExternalKnowledgeApis),
        ]

        args_missing_knowledge_id = {
            "name": "Dataset",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": None,
        }

        with pytest.raises(ValueError, match="external_knowledge_id is required"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id)

        args_missing_api_id = {
            "name": "Dataset",
            "external_knowledge_api_id": None,
            "external_knowledge_id": "k-1",
        }

        with pytest.raises(ValueError, match="external_knowledge_api_id is required"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id)


# ---------------------------------------------------------------------------
# fetch_external_knowledge_retrieval
# ---------------------------------------------------------------------------


class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
    """
    Tests for ``fetch_external_knowledge_retrieval`` which orchestrates
    external retrieval requests and normalises the response payload.
    """

    @pytest.fixture
    def mock_db_session(self):
        with patch("services.external_knowledge_service.db.session") as mock_session:
            yield mock_session

    def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
        """
        With a valid binding and API template, records from the external
        service should be returned when the HTTP response is 200.
        """

        tenant_id = "tenant-1"
        dataset_id = "ds-1"
        query = "test query"
        external_retrieval_parameters = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5}

        binding = ExternalDatasetTestDataFactory.create_external_binding(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            api_id="api-1",
            external_knowledge_id="knowledge-1",
        )

        api = Mock(spec=ExternalKnowledgeApis)
        api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'

        # First query: binding; second query: api.
        mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
            binding,
            api,
        ]

        fake_records = [{"content": "doc", "score": 0.9}]
        fake_response = Mock(spec=httpx.Response)
        fake_response.status_code = 200
        fake_response.json.return_value = {"records": fake_records}

        metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})

        with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process:
            result = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                query=query,
                external_retrieval_parameters=external_retrieval_parameters,
                metadata_condition=metadata_condition,
            )

        assert result == fake_records

        mock_process.assert_called_once()
        setting_arg = mock_process.call_args.args[0]
        assert isinstance(setting_arg, ExternalKnowledgeApiSetting)
        assert setting_arg.url.endswith("/retrieval")

    def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock):
        """
        Missing binding should raise ``ValueError``.
        """

        mock_db_session.query.return_value.filter_by.return_value.first.return_value = None

        with pytest.raises(ValueError, match="external knowledge binding not found"):
            ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="missing",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )

    def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock):
        """
        When the API template is missing or has no settings, a ``ValueError`` is raised.
        """

        binding = ExternalDatasetTestDataFactory.create_external_binding()
        mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
            binding,
            None,
        ]

        with pytest.raises(ValueError, match="external api template not found"):
            ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="ds-1",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )

    def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock):
        """
        Non-200 responses should be treated as an empty result set.
        """

        binding = ExternalDatasetTestDataFactory.create_external_binding()
        api = Mock(spec=ExternalKnowledgeApis)
        api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'

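        # As in the success test above, the first lookup is assumed to resolve
        # the binding and the second the API template.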
        mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
            binding,
            api,
        ]

        fake_response = Mock(spec=httpx.Response)
        fake_response.status_code = 500
        fake_response.json.return_value = {}

        with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response):
            result = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="ds-1",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )

        assert result == []

@@ -0,0 +1,802 @@

"""
|
||||
Unit tests for HitTestingService.
|
||||
|
||||
This module contains comprehensive unit tests for the HitTestingService class,
|
||||
which handles retrieval testing operations for datasets, including internal
|
||||
dataset retrieval and external knowledge base retrieval.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from models import Account
|
||||
from models.dataset import Dataset
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
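# The tests below assume a response payload shaped roughly like
#     {"query": {"content": <query>}, "records": [{"content": ..., "score": ...}, ...]}
# (inferred from the assertions rather than from a documented contract).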
class HitTestingTestDataFactory:
    """
    Factory class for creating test data and mock objects for hit testing service tests.

    This factory provides static methods to create mock objects for datasets, users,
    documents, and retrieval records used in HitTestingService unit tests.
    """

    @staticmethod
    def create_dataset_mock(
        dataset_id: str = "dataset-123",
        tenant_id: str = "tenant-123",
        provider: str = "vendor",
        retrieval_model: dict | None = None,
        **kwargs,
    ) -> Mock:
        """
        Create a mock dataset with specified attributes.

        Args:
            dataset_id: Unique identifier for the dataset
            tenant_id: Tenant identifier
            provider: Dataset provider (vendor, external, etc.)
            retrieval_model: Optional retrieval model configuration
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as a Dataset instance
        """
        dataset = Mock(spec=Dataset)
        dataset.id = dataset_id
        dataset.tenant_id = tenant_id
        dataset.provider = provider
        dataset.retrieval_model = retrieval_model
        for key, value in kwargs.items():
            setattr(dataset, key, value)
        return dataset

    @staticmethod
    def create_user_mock(
        user_id: str = "user-789",
        tenant_id: str = "tenant-123",
        **kwargs,
    ) -> Mock:
        """
        Create a mock user (Account) with specified attributes.

        Args:
            user_id: Unique identifier for the user
            tenant_id: Tenant identifier
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as an Account instance
        """
        user = Mock(spec=Account)
        user.id = user_id
        user.current_tenant_id = tenant_id
        user.name = "Test User"
        for key, value in kwargs.items():
            setattr(user, key, value)
        return user

    @staticmethod
    def create_document_mock(
        content: str = "Test document content",
        metadata: dict | None = None,
        **kwargs,
    ) -> Mock:
        """
        Create a mock Document from core.rag.models.document.

        Args:
            content: Document content/text
            metadata: Optional metadata dictionary
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock object configured as a Document instance
        """
        document = Mock(spec=Document)
        document.page_content = content
        document.metadata = metadata or {}
        for key, value in kwargs.items():
            setattr(document, key, value)
        return document

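    # ``model_dump`` below stands in for the pydantic-style records that
    # ``RetrievalService.format_retrieval_documents`` is assumed to return;
    # only the dumped dict is asserted on in these tests.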
    @staticmethod
    def create_retrieval_record_mock(
        content: str = "Test content",
        score: float = 0.95,
        **kwargs,
    ) -> Mock:
        """
        Create a mock retrieval record.

        Args:
            content: Record content
            score: Retrieval score
            **kwargs: Additional fields for the record

        Returns:
            Mock object with model_dump method returning record data
        """
        record = Mock()
        record.model_dump.return_value = {
            "content": content,
            "score": score,
            **kwargs,
        }
        return record


class TestHitTestingServiceRetrieve:
    """
    Tests for HitTestingService.retrieve method (hit_testing).

    This test class covers the main retrieval testing functionality, including
    various retrieval model configurations, metadata filtering, and query logging.
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Mock database session.

        Provides a mocked database session for testing database operations
        like adding and committing DatasetQuery records.
        """
        with patch("services.hit_testing_service.db.session") as mock_db:
            yield mock_db

    def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
        """
        Test successful retrieval with default retrieval model.

        Verifies that the retrieve method works correctly when no custom
        retrieval model is provided, using the default retrieval configuration.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=None)
        account = HitTestingTestDataFactory.create_user_mock()
        query = "test query"
        retrieval_model = None
        external_retrieval_model = {}

        documents = [
            HitTestingTestDataFactory.create_document_mock(content="Doc 1"),
            HitTestingTestDataFactory.create_document_mock(content="Doc 2"),
        ]

        mock_records = [
            HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 1"),
            HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2"),
        ]

        with (
            patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
            patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
            patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
        ):
            mock_perf_counter.side_effect = [0.0, 0.1]  # start, end
            mock_retrieve.return_value = documents
            mock_format.return_value = mock_records

            # Act
            result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)

            # Assert
            assert result["query"]["content"] == query
            assert len(result["records"]) == 2
            mock_retrieve.assert_called_once()
            mock_db_session.add.assert_called_once()
            mock_db_session.commit.assert_called_once()

    def test_retrieve_success_with_custom_retrieval_model(self, mock_db_session):
        """
        Test successful retrieval with custom retrieval model.

        Verifies that custom retrieval model parameters (search method, reranking,
        score threshold, etc.) are properly passed to RetrievalService.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock()
        account = HitTestingTestDataFactory.create_user_mock()
        query = "test query"
        retrieval_model = {
            "search_method": RetrievalMethod.KEYWORD_SEARCH,
            "reranking_enable": True,
            "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-1"},
            "top_k": 5,
            "score_threshold_enabled": True,
            "score_threshold": 0.7,
            "weights": {"vector_setting": 0.5, "keyword_setting": 0.5},
        }
        external_retrieval_model = {}

        documents = [HitTestingTestDataFactory.create_document_mock()]
        mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]

        with (
            patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
            patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
            patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
        ):
            mock_perf_counter.side_effect = [0.0, 0.1]
            mock_retrieve.return_value = documents
            mock_format.return_value = mock_records

            # Act
            result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)

            # Assert
            assert result["query"]["content"] == query
            mock_retrieve.assert_called_once()
            call_kwargs = mock_retrieve.call_args[1]
            assert call_kwargs["retrieval_method"] == RetrievalMethod.KEYWORD_SEARCH
            assert call_kwargs["top_k"] == 5
            assert call_kwargs["score_threshold"] == 0.7
            assert call_kwargs["reranking_model"] == retrieval_model["reranking_model"]

    def test_retrieve_with_metadata_filtering(self, mock_db_session):
        """
        Test retrieval with metadata filtering conditions.

        Verifies that metadata filtering conditions are properly processed
        and document ID filters are applied to the retrieval query.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock()
        account = HitTestingTestDataFactory.create_user_mock()
        query = "test query"
        retrieval_model = {
            "metadata_filtering_conditions": {
                "conditions": [
                    {"field": "category", "operator": "is", "value": "test"},
                ],
            },
        }
        external_retrieval_model = {}

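        # Stubbed to return a (document-ids-by-dataset, <flag/condition>) pair;
        # these tests only rely on the first element, which feeds
        # ``document_ids_filter``.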
        mock_dataset_retrieval = MagicMock()
        mock_dataset_retrieval.get_metadata_filter_condition.return_value = (
            {dataset.id: ["doc-1", "doc-2"]},
            None,
        )

        documents = [HitTestingTestDataFactory.create_document_mock()]
        mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]

        with (
            patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
            patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
            patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
            patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
        ):
            mock_perf_counter.side_effect = [0.0, 0.1]
            mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
            mock_retrieve.return_value = documents
            mock_format.return_value = mock_records

            # Act
            result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)

            # Assert
            assert result["query"]["content"] == query
            mock_dataset_retrieval.get_metadata_filter_condition.assert_called_once()
            call_kwargs = mock_retrieve.call_args[1]
            assert call_kwargs["document_ids_filter"] == ["doc-1", "doc-2"]

    def test_retrieve_with_metadata_filtering_no_documents(self, mock_db_session):
        """
        Test retrieval with metadata filtering that returns no documents.

        Verifies that when metadata filtering results in no matching documents,
        an empty result is returned without calling RetrievalService.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock()
        account = HitTestingTestDataFactory.create_user_mock()
        query = "test query"
        retrieval_model = {
            "metadata_filtering_conditions": {
                "conditions": [
                    {"field": "category", "operator": "is", "value": "test"},
                ],
            },
        }
        external_retrieval_model = {}

        mock_dataset_retrieval = MagicMock()
        mock_dataset_retrieval.get_metadata_filter_condition.return_value = ({}, True)

        with (
            patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
            patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
        ):
            mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
            mock_format.return_value = []

            # Act
            result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)

            # Assert
            assert result["query"]["content"] == query
            assert result["records"] == []

    def test_retrieve_with_dataset_retrieval_model(self, mock_db_session):
        """
        Test retrieval using dataset's retrieval model when not provided.

        Verifies that when no retrieval model is provided, the dataset's
        retrieval model is used as a fallback.
        """
        # Arrange
        dataset_retrieval_model = {
            "search_method": RetrievalMethod.HYBRID_SEARCH,
            "top_k": 3,
        }
        dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=dataset_retrieval_model)
        account = HitTestingTestDataFactory.create_user_mock()
        query = "test query"
        retrieval_model = None
        external_retrieval_model = {}

        documents = [HitTestingTestDataFactory.create_document_mock()]
        mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]

        with (
            patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
            patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
            patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
        ):
            mock_perf_counter.side_effect = [0.0, 0.1]
            mock_retrieve.return_value = documents
            mock_format.return_value = mock_records

            # Act
            result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)

            # Assert
            assert result["query"]["content"] == query
            call_kwargs = mock_retrieve.call_args[1]
            assert call_kwargs["retrieval_method"] == RetrievalMethod.HYBRID_SEARCH
            assert call_kwargs["top_k"] == 3


class TestHitTestingServiceExternalRetrieve:
    """
    Tests for HitTestingService.external_retrieve method.

    This test class covers external knowledge base retrieval functionality,
    including query escaping, response formatting, and provider validation.
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Mock database session.

        Provides a mocked database session for testing database operations
        like adding and committing DatasetQuery records.
        """
        with patch("services.hit_testing_service.db.session") as mock_db:
            yield mock_db

    def test_external_retrieve_success(self, mock_db_session):
        """
        Test successful external retrieval.

        Verifies that external knowledge base retrieval works correctly,
        including query escaping, document formatting, and query logging.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
        account = HitTestingTestDataFactory.create_user_mock()
        query = 'test query with "quotes"'
        external_retrieval_model = {"top_k": 5, "score_threshold": 0.8}
        metadata_filtering_conditions = {}

        external_documents = [
            {"content": "External doc 1", "title": "Title 1", "score": 0.95, "metadata": {"key": "value"}},
            {"content": "External doc 2", "title": "Title 2", "score": 0.85, "metadata": {}},
        ]

        with (
            patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
            patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
        ):
            mock_perf_counter.side_effect = [0.0, 0.1]
            mock_external_retrieve.return_value = external_documents

            # Act
            result = HitTestingService.external_retrieve(
                dataset, query, account, external_retrieval_model, metadata_filtering_conditions
            )

            # Assert
            assert result["query"]["content"] == query
            assert len(result["records"]) == 2
            assert result["records"][0]["content"] == "External doc 1"
            assert result["records"][0]["title"] == "Title 1"
            assert result["records"][0]["score"] == 0.95
            mock_external_retrieve.assert_called_once()
            # Verify query was escaped
            assert mock_external_retrieve.call_args[1]["query"] == 'test query with \\"quotes\\"'
            mock_db_session.add.assert_called_once()
            mock_db_session.commit.assert_called_once()

    def test_external_retrieve_non_external_provider(self, mock_db_session):
        """
        Test external retrieval with non-external provider (should return empty).

        Verifies that when the dataset provider is not "external", the method
        returns an empty result without performing retrieval or database operations.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock(provider="vendor")
        account = HitTestingTestDataFactory.create_user_mock()
        query = "test query"
        external_retrieval_model = {}
        metadata_filtering_conditions = {}

        # Act
        result = HitTestingService.external_retrieve(
            dataset, query, account, external_retrieval_model, metadata_filtering_conditions
        )

        # Assert
        assert result["query"]["content"] == query
        assert result["records"] == []
        mock_db_session.add.assert_not_called()

    def test_external_retrieve_with_metadata_filtering(self, mock_db_session):
        """
        Test external retrieval with metadata filtering conditions.

        Verifies that metadata filtering conditions are properly passed
        to the external retrieval service.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
        account = HitTestingTestDataFactory.create_user_mock()
        query = "test query"
        external_retrieval_model = {"top_k": 3}
        metadata_filtering_conditions = {"category": "test"}

        external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]

        with (
            patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
            patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
        ):
            mock_perf_counter.side_effect = [0.0, 0.1]
            mock_external_retrieve.return_value = external_documents

            # Act
            result = HitTestingService.external_retrieve(
                dataset, query, account, external_retrieval_model, metadata_filtering_conditions
            )

            # Assert
            assert result["query"]["content"] == query
            assert len(result["records"]) == 1
            call_kwargs = mock_external_retrieve.call_args[1]
            assert call_kwargs["metadata_filtering_conditions"] == metadata_filtering_conditions

    def test_external_retrieve_empty_documents(self, mock_db_session):
        """
        Test external retrieval with empty document list.

        Verifies that when external retrieval returns no documents,
        an empty result is properly formatted and returned.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
        account = HitTestingTestDataFactory.create_user_mock()
        query = "test query"
        external_retrieval_model = {}
        metadata_filtering_conditions = {}

        with (
            patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
            patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
        ):
            mock_perf_counter.side_effect = [0.0, 0.1]
            mock_external_retrieve.return_value = []

            # Act
            result = HitTestingService.external_retrieve(
                dataset, query, account, external_retrieval_model, metadata_filtering_conditions
            )

            # Assert
            assert result["query"]["content"] == query
            assert result["records"] == []


class TestHitTestingServiceCompactRetrieveResponse:
    """
    Tests for HitTestingService.compact_retrieve_response method.

    This test class covers response formatting for internal dataset retrieval,
    ensuring documents are properly formatted into retrieval records.
    """

    def test_compact_retrieve_response_success(self):
        """
        Test successful response formatting.

        Verifies that documents are properly formatted into retrieval records
        with correct structure and data.
        """
        # Arrange
        query = "test query"
        documents = [
            HitTestingTestDataFactory.create_document_mock(content="Doc 1"),
            HitTestingTestDataFactory.create_document_mock(content="Doc 2"),
        ]

        mock_records = [
            HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 1", score=0.95),
            HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2", score=0.85),
        ]

        with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
            mock_format.return_value = mock_records

            # Act
            result = HitTestingService.compact_retrieve_response(query, documents)

            # Assert
            assert result["query"]["content"] == query
            assert len(result["records"]) == 2
            assert result["records"][0]["content"] == "Doc 1"
            assert result["records"][0]["score"] == 0.95
            mock_format.assert_called_once_with(documents)

    def test_compact_retrieve_response_empty_documents(self):
        """
        Test response formatting with empty document list.

        Verifies that an empty document list results in an empty records array
        while maintaining the correct response structure.
        """
        # Arrange
        query = "test query"
        documents = []

        with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
            mock_format.return_value = []

            # Act
            result = HitTestingService.compact_retrieve_response(query, documents)

            # Assert
            assert result["query"]["content"] == query
            assert result["records"] == []


class TestHitTestingServiceCompactExternalRetrieveResponse:
    """
    Tests for HitTestingService.compact_external_retrieve_response method.

    This test class covers response formatting for external knowledge base
    retrieval, ensuring proper field extraction and provider validation.
    """

    def test_compact_external_retrieve_response_external_provider(self):
        """
        Test external response formatting for external provider.

        Verifies that external documents are properly formatted with all
        required fields (content, title, score, metadata).
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
        query = "test query"
        documents = [
            {"content": "Doc 1", "title": "Title 1", "score": 0.95, "metadata": {"key": "value"}},
            {"content": "Doc 2", "title": "Title 2", "score": 0.85, "metadata": {}},
        ]

        # Act
        result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)

        # Assert
        assert result["query"]["content"] == query
        assert len(result["records"]) == 2
        assert result["records"][0]["content"] == "Doc 1"
        assert result["records"][0]["title"] == "Title 1"
        assert result["records"][0]["score"] == 0.95
        assert result["records"][0]["metadata"] == {"key": "value"}

    def test_compact_external_retrieve_response_non_external_provider(self):
        """
        Test external response formatting for non-external provider.

        Verifies that non-external providers return an empty records array
        regardless of input documents.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock(provider="vendor")
        query = "test query"
        documents = [{"content": "Doc 1"}]

        # Act
        result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)

        # Assert
        assert result["query"]["content"] == query
        assert result["records"] == []

    def test_compact_external_retrieve_response_missing_fields(self):
        """
        Test external response formatting with missing optional fields.

        Verifies that missing optional fields (title, score, metadata) are
        handled gracefully by setting them to None.
        """
        # Arrange
        dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
        query = "test query"
        documents = [
            {"content": "Doc 1"},  # Missing title, score, metadata
            {"content": "Doc 2", "title": "Title 2"},  # Missing score, metadata
        ]

        # Act
        result = HitTestingService.compact_external_retrieve_response(dataset, query, documents)

        # Assert
        assert result["query"]["content"] == query
        assert len(result["records"]) == 2
        assert result["records"][0]["content"] == "Doc 1"
        assert result["records"][0]["title"] is None
        assert result["records"][0]["score"] is None
        assert result["records"][0]["metadata"] is None


class TestHitTestingServiceHitTestingArgsCheck:
    """
    Tests for HitTestingService.hit_testing_args_check method.

    This test class covers query argument validation, ensuring queries
    meet the required criteria (non-empty, max 250 characters).
    """

    def test_hit_testing_args_check_success(self):
        """
        Test successful argument validation.

        Verifies that valid queries pass validation without raising errors.
        """
        # Arrange
        args = {"query": "valid query"}

        # Act & Assert (should not raise)
        HitTestingService.hit_testing_args_check(args)

    def test_hit_testing_args_check_empty_query(self):
        """
        Test validation fails with empty query.

        Verifies that empty queries raise a ValueError with appropriate message.
        """
        # Arrange
        args = {"query": ""}

        # Act & Assert
        with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
            HitTestingService.hit_testing_args_check(args)

    def test_hit_testing_args_check_none_query(self):
        """
        Test validation fails with None query.

        Verifies that None queries raise a ValueError with appropriate message.
        """
        # Arrange
        args = {"query": None}

        # Act & Assert
        with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
            HitTestingService.hit_testing_args_check(args)

    def test_hit_testing_args_check_too_long_query(self):
        """
        Test validation fails with query exceeding 250 characters.

        Verifies that queries longer than 250 characters raise a ValueError.
        """
        # Arrange
        args = {"query": "a" * 251}

        # Act & Assert
        with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
            HitTestingService.hit_testing_args_check(args)

    def test_hit_testing_args_check_exactly_250_characters(self):
        """
        Test validation succeeds with exactly 250 characters.

        Verifies that queries with exactly 250 characters (the maximum)
        pass validation successfully.
        """
        # Arrange
        args = {"query": "a" * 250}

        # Act & Assert (should not raise)
        HitTestingService.hit_testing_args_check(args)


class TestHitTestingServiceEscapeQueryForSearch:
    """
    Tests for HitTestingService.escape_query_for_search method.

    This test class covers query escaping functionality for external search,
    ensuring special characters are properly escaped.
    """

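    # The expected values assume the current behaviour of
    # ``escape_query_for_search``: backslash-escaping double quotes so the
    # query can be embedded in a JSON-quoted string sent to the external API.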
    def test_escape_query_for_search_with_quotes(self):
        """
        Test escaping quotes in query.

        Verifies that double quotes in queries are properly escaped with
        backslashes for external search compatibility.
        """
        # Arrange
        query = 'test query with "quotes"'

        # Act
        result = HitTestingService.escape_query_for_search(query)

        # Assert
        assert result == 'test query with \\"quotes\\"'

    def test_escape_query_for_search_without_quotes(self):
        """
        Test query without quotes (no change).

        Verifies that queries without quotes remain unchanged after escaping.
        """
        # Arrange
        query = "test query without quotes"

        # Act
        result = HitTestingService.escape_query_for_search(query)

        # Assert
        assert result == query

    def test_escape_query_for_search_multiple_quotes(self):
        """
        Test escaping multiple quotes in query.

        Verifies that all occurrences of double quotes in a query are
        properly escaped, not just the first one.
        """
        # Arrange
        query = 'test "query" with "multiple" quotes'

        # Act
        result = HitTestingService.escape_query_for_search(query)

        # Assert
        assert result == 'test \\"query\\" with \\"multiple\\" quotes'

    def test_escape_query_for_search_empty_string(self):
        """
        Test escaping empty string.

        Verifies that empty strings are handled correctly and remain empty
        after the escaping operation.
        """
        # Arrange
        query = ""

        # Act
        result = HitTestingService.escape_query_for_search(query)

        # Assert
        assert result == ""