Merge branch 'fix/create-app-from-dsl' into deploy/dev

This commit is contained in:
Yansong Zhang 2026-04-15 14:12:21 +08:00
commit b06ab95084
677 changed files with 8068 additions and 7119 deletions

View File

@ -62,7 +62,7 @@ jobs:
- name: Render coverage markdown from structured data
id: render
run: |
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
--base base_report.json \
< pr_report.json)"

View File

@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false

View File

@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
default=0,
)
REDIS_KEY_PREFIX: str = Field(
description="Optional global prefix for Redis keys, topics, and transport artifacts",
default="",
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
default=False,

View File

@ -6,10 +6,9 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@ -31,13 +30,14 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportMode
from services.entities.dsl_entities import ImportMode, ImportStatus
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
InfoList,
@ -161,15 +161,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
return value
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
    # Return a signed file URL for an app icon, or None when there is nothing to sign.
    if icon is None or icon_type is None:
        return None
    # Normalize enum members and raw strings to a plain string type name.
    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
    # NOTE(review): compares a lowercased str to an enum member — this only matches
    # if IconType is a str-valued enum; confirm against the IconType definition.
    if icon_type_value.lower() != IconType.IMAGE:
        return None
    return file_helpers.get_signed_file_url(icon)
class Tag(ResponseModel):
id: str
name: str
@ -292,7 +283,7 @@ class Site(ResponseModel):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
@field_validator("icon_type", mode="before")
@classmethod
@ -342,7 +333,7 @@ class AppPartial(ResponseModel):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
@ -390,7 +381,7 @@ class AppDetailWithSite(AppDetail):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
class AppPagination(ResponseModel):
@ -632,7 +623,7 @@ class AppCopyApi(Resource):
args = CopyAppPayload.model_validate(console_ns.payload or {})
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
result = import_service.import_app(
@ -645,6 +636,13 @@ class AppCopyApi(Resource):
icon=args.icon,
icon_background=args.icon_background,
)
if result.status == ImportStatus.FAILED:
session.rollback()
return result.model_dump(mode="json"), 400
if result.status == ImportStatus.PENDING:
session.rollback()
return result.model_dump(mode="json"), 202
session.commit()
# Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,6 +1,6 @@
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console.app.wraps import get_app_model
@ -52,8 +52,9 @@ class AppImportApi(Resource):
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with sessionmaker(db.engine).begin() as session:
# AppDslService performs internal commits for some creation paths, so use a plain
# Session here instead of nesting it inside sessionmaker(...).begin().
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Import app
account = current_user
@ -69,6 +70,10 @@ class AppImportApi(Resource):
icon_background=args.icon_background,
app_id=args.app_id,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource):
# Check user role first
current_user, _ = current_account_with_tenant()
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_model: App):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@ -1,44 +1,86 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.conversation_variable_fields import (
conversation_variable_fields,
paginated_conversation_variable_fields,
)
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
    """Query-string parameters for the conversation-variables listing endpoint."""

    conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
# For nested models, need to replace nested dict with registered model
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
paginated_conversation_variable_fields_copy["data"] = fields.List(
fields.Nested(conversation_variable_model), attribute="data"
)
paginated_conversation_variable_model = console_ns.model(
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
class ConversationVariableResponse(ResponseModel):
    """A single conversation variable serialized for the console API."""

    id: str
    name: str
    value_type: str
    value: str | None = None
    description: str | None = None
    created_at: int | None = None
    updated_at: int | None = None

    @field_validator("value_type", mode="before")
    @classmethod
    def _normalize_value_type(cls, value: Any) -> str:
        # Prefer an `exposed_type()` callable when the incoming object provides
        # one — presumably a variable segment reporting its user-facing type.
        exposed_type = getattr(value, "exposed_type", None)
        if callable(exposed_type):
            return str(exposed_type().value)
        if isinstance(value, str):
            return value
        # Fall back to the shared serializer; if it rejects the bare value,
        # retry with the value wrapped in the mapping shape it expects.
        try:
            return serialize_value_type(value)
        except Exception:
            return serialize_value_type({"value_type": value})

    @field_validator("value", mode="before")
    @classmethod
    def _normalize_value(cls, value: Any | None) -> str | None:
        # Stringify non-string values so the response always carries text.
        if value is None:
            return None
        if isinstance(value, str):
            return value
        return str(value)

    @field_validator("created_at", "updated_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # ORM rows supply datetimes; clients receive unix timestamps.
        return _to_timestamp(value)
class PaginatedConversationVariableResponse(ResponseModel):
    """One page of conversation variables plus pagination metadata."""

    page: int
    limit: int
    total: int
    has_more: bool
    data: list[ConversationVariableResponse]
register_schema_models(
console_ns,
ConversationVariablesQuery,
ConversationVariableResponse,
PaginatedConversationVariableResponse,
)
@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@console_ns.response(
200,
"Conversation variables retrieved successfully",
console_ns.models[PaginatedConversationVariableResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
rows = session.scalars(stmt).all()
return {
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}
response = PaginatedConversationVariableResponse.model_validate(
{
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
ConversationVariableResponse.model_validate(
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
)
for row in rows
],
}
)
return response.model_dump(mode="json")

View File

@ -1,8 +1,9 @@
import logging
from datetime import datetime
from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select
@ -25,10 +26,21 @@ from controllers.console.wraps import (
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from fields.base import ResponseModel
from fields.conversation_fields import (
AgentThought,
ConversationAnnotation,
ConversationAnnotationHitHistory,
Feedback,
JSONValue,
MessageFile,
format_files_contained,
to_timestamp,
)
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
class MessageDetailResponse(ResponseModel):
    """Detail view of a single chat message for the console API."""

    id: str
    conversation_id: str
    # Raw inputs mapping; files embedded in it are normalized by the validator below.
    inputs: dict[str, JSONValue]
    query: str
    message: JSONValue | None = None
    message_tokens: int | None = None
    # Read from the model's `re_sign_file_url_answer` attribute rather than the
    # raw answer column — presumably the variant with re-signed file URLs.
    answer: str = Field(validation_alias="re_sign_file_url_answer")
    answer_tokens: int | None = None
    provider_response_latency: float | None = None
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    feedbacks: list[Feedback] = Field(default_factory=list)
    workflow_run_id: str | None = None
    annotation: ConversationAnnotation | None = None
    annotation_hit_history: ConversationAnnotationHitHistory | None = None
    created_at: int | None = None
    agent_thoughts: list[AgentThought] = Field(default_factory=list)
    message_files: list[MessageFile] = Field(default_factory=list)
    extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
    # Sourced from `message_metadata_dict` on the ORM model.
    metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
    status: str
    error: str | None = None
    parent_message_id: str | None = None

    @field_validator("inputs", mode="before")
    @classmethod
    def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
        # Delegate to the shared helper that processes file entries in the payload.
        return format_files_contained(value)

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        # Accept a datetime (ORM) or an already-converted int timestamp.
        if isinstance(value, datetime):
            return to_timestamp(value)
        return value
class MessageInfiniteScrollPaginationResponse(ResponseModel):
    """Infinite-scroll page of chat messages."""

    limit: int
    has_more: bool
    data: list[MessageDetailResponse]
register_schema_models(
console_ns,
ChatMessagesQuery,
@ -105,124 +162,8 @@ register_schema_models(
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model(
"SimpleAccount",
{
"id": fields.String,
"name": fields.String,
"email": fields.String,
},
)
message_file_model = console_ns.model(
"MessageFile",
{
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
},
)
agent_thought_model = console_ns.model(
"AgentThought",
{
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
},
)
# Models that depend on simple_account_model
feedback_model = console_ns.model(
"Feedback",
{
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_model, allow_null=True),
},
)
annotation_model = console_ns.model(
"Annotation",
{
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
annotation_hit_history_model = console_ns.model(
"AnnotationHitHistory",
{
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
"MessageDetail",
{
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_model)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_model, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"extra_contents": fields.List(fields.Raw),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
},
)
# Message infinite scroll pagination model
message_infinite_scroll_pagination_model = console_ns.model(
"MessageInfiniteScrollPagination",
{
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_model)),
},
MessageDetailResponse,
MessageInfiniteScrollPaginationResponse,
)
@ -232,13 +173,12 @@ class ChatMessageListApi(Resource):
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
@console_ns.response(404, "Conversation not found")
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
@ -298,7 +238,10 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
return MessageInfiniteScrollPaginationResponse.model_validate(
InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
from_attributes=True,
).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@ -468,13 +411,12 @@ class MessageApi(Resource):
@console_ns.doc("get_message")
@console_ns.doc(description="Get message details by ID")
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
@console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
@console_ns.response(404, "Message not found")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@marshal_with(message_detail_model)
def get(self, app_model, message_id: str):
message_id = str(message_id)
@ -486,4 +428,4 @@ class MessageApi(Resource):
raise NotFound("Message Not Exists.")
attach_message_extra_contents([message])
return message
return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")

View File

@ -4,7 +4,7 @@ from collections.abc import Sequence
from typing import Any
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource, fields, marshal, marshal_with
from graphon.enums import NodeType
from graphon.file import File
from graphon.graph_engine.manager import GraphEngineManager
@ -942,7 +942,6 @@ class PublishedAllWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_model)
@edit_permission_required
def get(self, app_model: App):
"""
@ -970,9 +969,10 @@ class PublishedAllWorkflowApi(Resource):
user_id=user_id,
named_only=named_only,
)
serialized_workflows = marshal(workflows, workflow_fields_copy)
return {
"items": workflows,
"items": serialized_workflows,
"page": page,
"limit": limit,
"has_more": has_more,

View File

@ -1,27 +1,26 @@
from datetime import datetime
from typing import Any
from dateutil.parser import isoparse
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_app_log_fields import (
build_workflow_app_log_pagination_model,
build_workflow_archived_log_pagination_model,
)
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs.login import login_required
from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel):
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class WorkflowRunForLogResponse(ResponseModel):
id: str
version: str | None = None
status: str | None = None
triggered_from: str | None = None
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
exceptions_count: int | None = None
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
@field_validator("status", mode="before")
@classmethod
def _normalize_status(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowRunForArchivedLogResponse(ResponseModel):
    """Compact summary of a workflow run attached to an archived log entry."""

    id: str
    status: str | None = None
    triggered_from: str | None = None
    elapsed_time: float | None = None
    total_tokens: int | None = None

    @field_validator("status", mode="before")
    @classmethod
    def _normalize_status(cls, raw: Any) -> str | None:
        # Strings and None pass through; enum-like objects yield their `.value`,
        # anything else is stringified.
        if raw is None or isinstance(raw, str):
            return raw
        return str(getattr(raw, "value", raw))
class WorkflowAppLogPartialResponse(ResponseModel):
    """One workflow app log entry with its optional run summary and creator info."""

    id: str
    workflow_run: WorkflowRunForLogResponse | None = None
    details: Any = None
    created_from: str | None = None
    created_by_role: str | None = None
    created_by_account: SimpleAccount | None = None
    created_by_end_user: SimpleEndUser | None = None
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # Accept datetimes from the ORM; pass through pre-converted ints.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value
class WorkflowArchivedLogPartialResponse(ResponseModel):
    """One archived workflow log entry with its optional run summary."""

    id: str
    workflow_run: WorkflowRunForArchivedLogResponse | None = None
    trigger_metadata: Any = None
    created_by_account: SimpleAccount | None = None
    created_by_end_user: SimpleEndUser | None = None
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # Accept datetimes from the ORM; pass through pre-converted ints.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value
class WorkflowAppLogPaginationResponse(ResponseModel):
    """One page of workflow app logs plus pagination metadata."""

    page: int
    limit: int
    total: int
    has_more: bool
    data: list[WorkflowAppLogPartialResponse]
class WorkflowArchivedLogPaginationResponse(ResponseModel):
    """One page of archived workflow logs plus pagination metadata."""

    page: int
    limit: int
    total: int
    has_more: bool
    data: list[WorkflowArchivedLogPartialResponse]
register_schema_models(
console_ns,
WorkflowAppLogQuery,
WorkflowRunForLogResponse,
WorkflowRunForArchivedLogResponse,
WorkflowAppLogPartialResponse,
WorkflowArchivedLogPartialResponse,
WorkflowAppLogPaginationResponse,
WorkflowArchivedLogPaginationResponse,
)
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource):
@console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@console_ns.response(
200,
"Workflow app logs retrieved successfully",
console_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_app_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow app logs
@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
return WorkflowAppLogPaginationResponse.model_validate(
workflow_app_log_pagination, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource):
@console_ns.doc(description="Get workflow archived execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
@console_ns.response(
200,
"Workflow archived logs retrieved successfully",
console_ns.models[WorkflowArchivedLogPaginationResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_archived_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow archived logs
@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource):
limit=args.limit,
)
return workflow_app_log_pagination
return WorkflowArchivedLogPaginationResponse.model_validate(
workflow_app_log_pagination, from_attributes=True
).model_dump(mode="json")

View File

@ -1,16 +1,17 @@
import logging
from datetime import datetime
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from flask_restx import Resource
from pydantic import BaseModel, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from fields.base import ResponseModel
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
@ -21,15 +22,6 @@ from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
triggers_list_fields_copy = triggers_list_fields.copy()
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
class Parser(BaseModel):
@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
enable_trigger: bool
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class WorkflowTriggerResponse(ResponseModel):
id: str
trigger_type: str
title: str
node_id: str
provider_name: str
icon: str
status: str
created_at: datetime | None = None
updated_at: datetime | None = None
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
@field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
@classmethod
def _normalize_string_fields(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
class WorkflowTriggerListResponse(ResponseModel):
    """Envelope for the app triggers list endpoint."""

    data: list[WorkflowTriggerResponse]
class WebhookTriggerResponse(ResponseModel):
    """Webhook trigger details for a workflow node, serialized for the console API."""

    id: str
    webhook_id: str
    webhook_url: str
    webhook_debug_url: str
    node_id: str
    created_at: datetime | None = None

    @field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
    @classmethod
    def _normalize_string_fields(cls, raw: object) -> str:
        # Coerce non-string inputs (e.g. UUID objects) to plain strings.
        return raw if isinstance(raw, str) else str(raw)
register_schema_models(
console_ns,
Parser,
ParserEnable,
WorkflowTriggerResponse,
WorkflowTriggerListResponse,
WebhookTriggerResponse,
)
@ -57,7 +91,7 @@ class WebhookTriggerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_model)
@console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -78,7 +112,7 @@ class WebhookTriggerApi(Resource):
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
return webhook_trigger
return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/triggers")
@ -89,7 +123,7 @@ class AppTriggersApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_model)
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
@ -118,7 +152,9 @@ class AppTriggersApi(Resource):
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
mode="json"
)
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_model)
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)
@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource):
else:
trigger.icon = "" # type: ignore
return trigger
return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")

View File

@ -1,3 +1,5 @@
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
@ -40,7 +42,7 @@ class ActivatePayload(BaseModel):
class ActivationCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether token is valid")
data: dict | None = Field(default=None, description="Activation data if valid")
data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):

View File

@ -1026,7 +1026,7 @@ class DocumentMetadataApi(DocumentResource):
if not isinstance(doc_metadata, dict):
raise ValueError("doc_metadata must be a dictionary.")
metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
document.doc_metadata = {}
if doc_type == "others":

View File

@ -1,13 +1,13 @@
from flask_restx import Resource, fields
from __future__ import annotations
from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from datetime import datetime
from typing import Any
from flask_restx import Resource
from pydantic import Field, field_validator
from controllers.common.schema import register_schema_models
from fields.base import ResponseModel
from libs.login import login_required
from .. import console_ns
@ -18,39 +18,92 @@ from ..wraps import (
setup_required,
)
register_schema_model(console_ns, HitTestingPayload)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def _get_or_create_model(model_name: str, field_def):
    """Return the namespace's registered model, creating it on first use.

    Reusing an already-registered model avoids duplicate-definition issues
    in the generated Swagger document.
    """
    model = console_ns.models.get(model_name)
    return model if model is not None else console_ns.model(model_name, field_def)
class HitTestingDocument(ResponseModel):
    """Source document metadata attached to a retrieved segment."""

    id: str | None = None
    data_source_type: str | None = None
    name: str | None = None
    doc_type: str | None = None
    # Free-form metadata blob; not validated here.
    doc_metadata: Any | None = None
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
class HitTestingSegment(ResponseModel):
    """A retrieved document segment; lifecycle timestamps are exposed as epoch seconds."""

    id: str | None = None
    position: int | None = None
    document_id: str | None = None
    content: str | None = None
    sign_content: str | None = None
    answer: str | None = None
    word_count: int | None = None
    tokens: int | None = None
    keywords: list[str] = Field(default_factory=list)
    index_node_id: str | None = None
    index_node_hash: str | None = None
    hit_count: int | None = None
    enabled: bool | None = None
    disabled_at: int | None = None
    disabled_by: str | None = None
    status: str | None = None
    created_by: str | None = None
    created_at: int | None = None
    indexing_at: int | None = None
    completed_at: int | None = None
    error: str | None = None
    stopped_at: int | None = None
    document: HitTestingDocument | None = None

    # NOTE: stale marshal-era lines (segment_fields_copy/fields.Nested/segment_model)
    # that were left inside this class body after the pydantic migration have been
    # removed — they referenced names deleted by the migration and raised NameError
    # when the class body executed.
    @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # ORM rows carry datetimes; the API contract exposes unix timestamps.
        return _to_timestamp(value)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
class HitTestingChildChunk(ResponseModel):
    """Child chunk of a matched segment, with its own retrieval score."""

    id: str | None = None
    content: str | None = None
    position: int | None = None
    score: float | None = None
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
class HitTestingFile(ResponseModel):
    """File attachment associated with a hit-testing record."""

    id: str | None = None
    name: str | None = None
    size: int | None = None
    extension: str | None = None
    mime_type: str | None = None
    source_url: str | None = None
class HitTestingRecord(ResponseModel):
    """One retrieval hit: the matched segment plus scoring and attachments."""

    segment: HitTestingSegment | None = None
    child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
    score: float | None = None
    # Coordinates for visualization; structure produced upstream, not validated here.
    tsne_position: Any | None = None
    files: list[HitTestingFile] = Field(default_factory=list)
    summary: str | None = None
class HitTestingResponse(ResponseModel):
    """Top-level hit-testing result: the query and its retrieval records."""

    query: str
    records: list[HitTestingRecord] = Field(default_factory=list)
# Expose the payload and response schemas in the console Swagger document.
register_schema_models(
    console_ns,
    HitTestingPayload,
    HitTestingDocument,
    HitTestingSegment,
    HitTestingChildChunk,
    HitTestingFile,
    HitTestingRecord,
    HitTestingResponse,
)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(
200,
"Hit testing completed successfully",
model=console_ns.models[HitTestingResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required
@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")

View File

@ -1,21 +1,24 @@
import logging
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, Field, computed_field, field_validator
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel):
logger = logging.getLogger(__name__)
app_model = get_or_create_model("InstalledAppInfo", app_fields)
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
    """Return a signed file URL for image-type icons; None for any other icon."""
    if icon is None or icon_type is None:
        return None
    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
    # NOTE(review): comparing a lowercased str against the IconType member only
    # works if IconType is a str-valued enum whose IMAGE value is lowercase —
    # confirm; with a plain Enum this comparison is always unequal and the
    # function would always return None for image icons.
    if icon_type_value.lower() != IconType.IMAGE:
        return None
    return file_helpers.get_signed_file_url(icon)
installed_app_fields_copy = installed_app_fields.copy()
installed_app_fields_copy["app"] = fields.Nested(app_model)
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
installed_app_list_fields_copy = installed_app_list_fields.copy()
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
def _safe_primitive(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool, datetime)):
return value
return None
class InstalledAppInfoResponse(ResponseModel):
    """Summary of the app behind an installed-app entry."""

    id: str
    name: str | None = None
    mode: str | None = None
    icon_type: str | None = None
    icon: str | None = None
    icon_background: str | None = None

    @field_validator("mode", "icon_type", mode="before")
    @classmethod
    def _normalize_enum_like(cls, value: Any) -> str | None:
        # Accept raw strings or enum members; expose the enum's value as str.
        if value is None:
            return None
        if isinstance(value, str):
            return value
        return str(getattr(value, "value", value))

    @computed_field(return_type=str | None)  # type: ignore[prop-decorator]
    @property
    def icon_url(self) -> str | None:
        # Derived field: signed URL only when the icon is an uploaded image.
        return _build_icon_url(self.icon_type, self.icon)
class InstalledAppResponse(ResponseModel):
    """One installed-app row, flattened for the console list endpoint."""

    id: str
    app: InstalledAppInfoResponse
    app_owner_tenant_id: str
    is_pinned: bool
    last_used_at: int | None = None
    editable: bool
    uninstallable: bool

    @field_validator("app", mode="before")
    @classmethod
    def _normalize_app(cls, value: Any) -> Any:
        # Already-shaped dict payloads pass through; any other object (e.g. an
        # ORM App row) is reduced attribute-by-attribute to scalar values via
        # _safe_primitive, so non-scalar attributes become None rather than
        # failing validation.
        if isinstance(value, dict):
            return value
        return {
            "id": _safe_primitive(getattr(value, "id", "")) or "",
            "name": _safe_primitive(getattr(value, "name", None)),
            "mode": _safe_primitive(getattr(value, "mode", None)),
            "icon_type": _safe_primitive(getattr(value, "icon_type", None)),
            "icon": _safe_primitive(getattr(value, "icon", None)),
            "icon_background": _safe_primitive(getattr(value, "icon_background", None)),
            "use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
        }

    @field_validator("last_used_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # ORM datetime -> unix timestamp for the API contract.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value
class InstalledAppListResponse(ResponseModel):
    """Envelope for the installed-apps list endpoint."""

    installed_apps: list[InstalledAppResponse]
# Register all installed-app payload/query/response schemas for Swagger.
register_schema_models(
    console_ns,
    InstalledAppCreatePayload,
    InstalledAppUpdatePayload,
    InstalledAppsListQuery,
    InstalledAppInfoResponse,
    InstalledAppResponse,
    InstalledAppListResponse,
)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(installed_app_list_model)
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
def get(self):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
)
)
return {"installed_apps": installed_app_list}
return InstalledAppListResponse.model_validate(
{"installed_apps": installed_app_list}, from_attributes=True
).model_dump(mode="json")
@login_required
@account_initialization_required

View File

@ -1,66 +1,83 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, computed_field, field_validator
from constants.languages import languages
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import current_user, login_required
from services.recommended_app_service import RecommendedAppService
app_fields = {
"id": fields.String,
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_type": fields.String,
"icon_url": AppIconUrlField,
"icon_background": fields.String,
}
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
recommended_app_fields = {
"app": fields.Nested(app_model, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
recommended_app_list_fields = {
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
"categories": fields.List(fields.String),
}
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
class RecommendedAppsQuery(BaseModel):
    """Query string for the explore/recommended apps listing."""

    # Preferred language; the endpoint falls back to the default language when absent.
    language: str | None = Field(default=None)
console_ns.schema_model(
RecommendedAppsQuery.__name__,
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
class RecommendedAppInfoResponse(ResponseModel):
    """App summary shown on the explore/recommended apps page."""

    id: str
    name: str | None = None
    mode: str | None = None
    icon: str | None = None
    icon_type: str | None = None
    icon_background: str | None = None

    @staticmethod
    def _normalize_enum_like(value: Any) -> str | None:
        # Accept plain strings or enum members; emit the enum's value as str.
        if value is None:
            return None
        if isinstance(value, str):
            return value
        return str(getattr(value, "value", value))

    @field_validator("mode", "icon_type", mode="before")
    @classmethod
    def _normalize_enum_fields(cls, value: Any) -> str | None:
        return cls._normalize_enum_like(value)

    @computed_field(return_type=str | None)  # type: ignore[prop-decorator]
    @property
    def icon_url(self) -> str | None:
        # Derived: signed file URL for image icons, None otherwise.
        return build_icon_url(self.icon_type, self.icon)
class RecommendedAppResponse(ResponseModel):
    """A recommended app entry plus its listing metadata."""

    app: RecommendedAppInfoResponse | None = None
    app_id: str
    description: str | None = None
    copyright: str | None = None
    privacy_policy: str | None = None
    custom_disclaimer: str | None = None
    category: str | None = None
    position: int | None = None
    is_listed: bool | None = None
    can_trial: bool | None = None
class RecommendedAppListResponse(ResponseModel):
    """Envelope for the explore apps listing: entries plus category names."""

    recommended_apps: list[RecommendedAppResponse]
    categories: list[str]
# Publish the query and response schemas to the console Swagger document.
register_schema_models(
    console_ns,
    RecommendedAppsQuery,
    RecommendedAppInfoResponse,
    RecommendedAppResponse,
    RecommendedAppListResponse,
)
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_model)
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource):
else:
language_prefix = languages[0]
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
return RecommendedAppListResponse.model_validate(
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
from_attributes=True,
).model_dump(mode="json")
@console_ns.route("/explore/apps/<uuid:app_id>")

View File

@ -1,15 +1,18 @@
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from constants import HIDDEN_VALUE
from fields.api_based_extension_fields import api_based_extension_fields
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import register_schema_models
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel):
api_key: str = Field(description="API key for authentication")
register_schema_models(console_ns, APIBasedExtensionPayload)
class CodeBasedExtensionResponse(ResponseModel):
module: str = Field(description="Module name")
data: Any = Field(description="Extension data")
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
def _mask_api_key(api_key: str) -> str:
if not api_key:
return api_key
if len(api_key) <= 8:
return api_key[0] + "******" + api_key[-1]
return api_key[:3] + "******" + api_key[-3:]
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class APIBasedExtensionResponse(ResponseModel):
    """API-based extension row with the API key masked for display."""

    id: str
    name: str
    api_endpoint: str
    api_key: str
    created_at: int | None = None

    @field_validator("api_key", mode="before")
    @classmethod
    def _normalize_api_key(cls, value: str) -> str:
        # Never expose the raw key; only the masked form leaves the API.
        return _mask_api_key(value)

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        return _to_timestamp(value)
# Register object schemas; the list response needs an explicit TypeAdapter-derived
# schema because its top level is a JSON array, not an object model.
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
console_ns.schema_model(
    "APIBasedExtensionListResponse",
    TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]:
    """Serialize an ORM extension row to a JSON-safe dict (API key masked by the model)."""
    return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json")
@console_ns.route("/code-based-extension")
@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
),
console_ns.models[CodeBasedExtensionResponse.__name__],
)
@setup_required
@login_required
@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource):
def get(self):
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
return CodeBasedExtensionResponse(
module=query.module,
data=CodeBasedExtensionService.get_code_based_extension(query.module),
).model_dump(mode="json")
@console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource):
@console_ns.doc("get_api_based_extensions")
@console_ns.doc(description="Get all API-based extensions for current tenant")
@console_ns.response(200, "Success", api_based_extension_list_model)
@console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self):
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
]
@console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension")
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
@console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self):
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource):
api_key=payload.api_key,
)
return APIBasedExtensionService.save(extension_data)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data))
@console_ns.route("/api-based-extension/<uuid:id>")
@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("get_api_based_extension")
@console_ns.doc(description="Get API-based extension by ID")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.response(200, "Success", api_based_extension_model)
@console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self, id):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
)
@console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource):
if payload.api_key != HIDDEN_VALUE:
extension_data_from_db.api_key = payload.api_key
return APIBasedExtensionService.save(extension_data_from_db)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db))
@console_ns.doc("delete_api_based_extension")
@console_ns.doc(description="Delete API-based extension")

View File

@ -5,7 +5,7 @@ from typing import Any, Literal
import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
@ -37,9 +37,10 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
@ -178,17 +179,57 @@ def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
integrate_list_model = console_ns.model(
"AccountIntegrateList",
{"data": fields.List(fields.Nested(integrate_model))},
class AccountIntegrateResponse(ResponseModel):
    """A third-party provider binding for the current account."""

    provider: str
    created_at: int | None = None
    is_bound: bool
    link: str | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        # Datetimes become unix timestamps on the wire.
        return _to_timestamp(value)
class AccountIntegrateListResponse(ResponseModel):
    """Envelope wrapping the account's integration bindings."""

    data: list[AccountIntegrateResponse]
class EducationVerifyResponse(ResponseModel):
    """Token issued after education identity verification."""

    token: str | None = None
class EducationStatusResponse(ResponseModel):
    """Education plan status for the current account."""

    result: bool | None = None
    is_student: bool | None = None
    # Expiry as a unix timestamp (normalized from datetime below).
    expire_at: int | None = None
    allow_refresh: bool | None = None

    @field_validator("expire_at", mode="before")
    @classmethod
    def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
        return _to_timestamp(value)
class EducationAutocompleteResponse(ResponseModel):
    """Paginated autocomplete suggestions for education verification."""

    data: list[str] = Field(default_factory=list)
    curr_page: int | None = None
    has_next: bool | None = None
# Register the account/education response schemas for Swagger.
register_schema_models(
    console_ns,
    AccountIntegrateResponse,
    AccountIntegrateListResponse,
    EducationVerifyResponse,
    EducationStatusResponse,
    EducationAutocompleteResponse,
)
@ -359,7 +400,7 @@ class AccountIntegrateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_model)
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@ -395,7 +436,9 @@ class AccountIntegrateApi(Resource):
}
)
return {"data": integrate_data}
return AccountIntegrateListResponse(
data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data]
).model_dump(mode="json")
@console_ns.route("/account/delete/verify")
@ -447,31 +490,22 @@ class AccountDeleteUpdateFeedbackApi(Resource):
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
return BillingService.EducationIdentity.verify(account.id, account.email)
return EducationVerifyResponse.model_validate(
BillingService.EducationIdentity.verify(account.id, account.email) or {}
).model_dump(mode="json")
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
"is_student": fields.Boolean,
"expire_at": TimestampField,
"allow_refresh": fields.Boolean,
}
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@ -491,37 +525,33 @@ class EducationApi(Resource):
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(status_fields)
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
res = BillingService.EducationIdentity.status(account.id)
res = BillingService.EducationIdentity.status(account.id) or {}
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
return res
return EducationStatusResponse.model_validate(res).model_dump(mode="json")
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(data_fields)
@console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__])
def get(self):
payload = request.args.to_dict(flat=True)
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
return EducationAutocompleteResponse.model_validate(
BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {}
).model_dump(mode="json")
@console_ns.route("/account/change-email")

View File

@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource):
class ParserValidate(BaseModel):
model: str
model_type: ModelType
credentials: dict
credentials: dict[str, Any]
console_ns.schema_model(

View File

@ -1,8 +1,9 @@
import logging
from datetime import datetime
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@ -26,6 +27,7 @@ from controllers.console.wraps import (
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel):
name: str
class TenantInfoResponse(ResponseModel):
    """Workspace (tenant) details returned by the tenant info endpoint."""

    id: str
    name: str | None = None
    plan: str | None = None
    status: str | None = None
    created_at: int | None = None
    role: str | None = None
    in_trial: bool | None = None
    trial_end_reason: str | None = None
    # Free-form tenant configuration; not validated here.
    custom_config: dict | None = None
    trial_credits: int | None = None
    trial_credits_used: int | None = None
    next_credit_reset_date: int | None = None

    @field_validator("plan", "status", "trial_end_reason", mode="before")
    @classmethod
    def _normalize_enum_like(cls, value):
        # Enum members serialize via their .value; strings/None pass through.
        if value is None:
            return None
        if isinstance(value, str):
            return value
        return str(getattr(value, "value", value))

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None):
        # ORM datetime -> unix timestamp for the API contract.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value
def reg(cls: type[BaseModel]):
    """Register a pydantic model's JSON schema with the console namespace under its class name."""
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@ -66,6 +99,7 @@ reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
reg(TenantInfoResponse)
provider_fields = {
"provider_name": fields.String,
@ -180,7 +214,7 @@ class TenantApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tenant_fields)
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
def post(self):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
@ -200,7 +234,13 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
return WorkspaceService.get_tenant_info(tenant), 200
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)
@console_ns.route("/workspaces/switch")

View File

@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id)
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource):
name=args.name,
description=args.description,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400

View File

@ -1,7 +1,9 @@
from datetime import datetime
from typing import Any, Literal
from flask import request
from flask_restx import Resource
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound
@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
value: Any
class ConversationVariableResponse(ResponseModel):
    """Serialized view of one conversation variable for the service API."""

    id: str
    name: str
    # Exposed (public) type name of the variable; normalized by the validator below.
    value_type: str
    value: str | None = None
    description: str | None = None
    created_at: int | None = None  # Unix seconds.
    updated_at: int | None = None  # Unix seconds.

    @field_validator("value_type", mode="before")
    @classmethod
    def normalize_value_type(cls, value: Any) -> str:
        """Best-effort coercion of a segment type (object, name string, or raw value) to its exposed type name."""
        # SegmentType-like objects expose their public type via exposed_type().
        exposed_type = getattr(value, "exposed_type", None)
        if callable(exposed_type):
            return str(exposed_type().value)
        if isinstance(value, str):
            try:
                return str(SegmentType(value).exposed_type().value)
            except ValueError:
                # Not a known SegmentType name; pass the raw string through.
                return value
        try:
            return serialize_value_type(value)
        except (AttributeError, TypeError, ValueError):
            pass
        try:
            # Last resort: serialize_value_type may accept a wrapping mapping.
            return serialize_value_type({"value_type": value})
        except (AttributeError, TypeError, ValueError):
            value_attr = getattr(value, "value", None)
            if value_attr is not None:
                return str(value_attr)
        return str(value)

    @field_validator("value", mode="before")
    @classmethod
    def normalize_value(cls, value: Any | None) -> str | None:
        """Stringify non-string values; keep None as None."""
        if value is None:
            return None
        if isinstance(value, str):
            return value
        return str(value)

    @field_validator("created_at", "updated_at", mode="before")
    @classmethod
    def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        """Convert ORM datetimes to Unix seconds."""
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
    """Page of conversation variables for last-id based infinite scrolling."""

    limit: int
    has_more: bool
    data: list[ConversationVariableResponse]
# Expose the conversation request/response schemas on the service API
# namespace so @service_api_ns.response(...) can reference them by class name.
register_schema_models(
    service_api_ns,
    ConversationListQuery,
    ConversationRenamePayload,
    ConversationVariablesQuery,
    ConversationVariableUpdatePayload,
    ConversationVariableResponse,
    ConversationVariableInfiniteScrollPaginationResponse,
)
@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
404: "Conversation not found",
}
)
@service_api_ns.response(
200,
"Variables retrieved successfully",
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser, c_id):
"""List all variables for a conversation.
@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
last_id = str(query_args.last_id) if query_args.last_id else None
try:
return ConversationService.get_conversational_variable(
pagination = ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
)
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
pagination, from_attributes=True
).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
404: "Conversation or variable not found",
}
)
@service_api_ns.response(
200,
"Variable updated successfully",
service_api_ns.models[ConversationVariableResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value.
@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.update_conversation_variable(
variable = ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, payload.value
)
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationVariableNotExistsError:

View File

@ -1,13 +1,15 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal
from dateutil.parser import isoparse
from flask import request
from flask_restx import Namespace, Resource, fields
from flask_restx import Resource, fields
from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@ -33,9 +35,10 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
# Register request schemas for the workflow service API endpoints.
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def _enum_value(value):
return getattr(value, "value", value)
class WorkflowRunStatusField(fields.Raw):
    """flask-restx field rendering a run's status as its plain string value."""

    def output(self, key, obj: WorkflowRun, **kwargs):
        # Merge residue removed: an unreachable duplicate `return obj.status.value`
        # (old diff side) preceded this line. Status may be an enum or already a
        # plain string; _enum_value handles both.
        return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw):
    """flask-restx field rendering run outputs, hidden while the run is paused."""

    def output(self, key, obj: WorkflowRun, **kwargs):
        # Merge residue removed: the old-side comparison against the raw enum
        # (`if obj.status == WorkflowExecutionStatus.PAUSED:`) was interleaved here.
        status = _enum_value(obj.status)
        if status == WorkflowExecutionStatus.PAUSED.value:
            # A paused run has no final outputs yet.
            return {}
        outputs = obj.outputs_dict
        return outputs or {}
# flask-restx marshal fields describing a workflow run.
# NOTE(review): WorkflowRunResponse below covers the same shape via Pydantic —
# confirm which callers still rely on this marshal mapping.
workflow_run_fields = {
    "id": fields.String,
    "workflow_id": fields.String,
    "status": WorkflowRunStatusField,
    "inputs": fields.Raw,
    "outputs": WorkflowRunOutputsField,
    "error": fields.String,
    "total_steps": fields.Integer,
    "total_tokens": fields.Integer,
    "created_at": TimestampField,
    "finished_at": OptionalTimestampField,
    "elapsed_time": fields.Float,
}
class WorkflowRunResponse(ResponseModel):
    """JSON-serializable view of a single workflow run."""

    id: str
    workflow_id: str
    status: str
    # Raw run inputs; shape depends on the app's input schema.
    inputs: dict | list | str | int | float | bool | None = None
    outputs: dict = Field(default_factory=dict)
    error: str | None = None
    total_steps: int | None = None
    total_tokens: int | None = None
    created_at: int | None = None  # Unix seconds (see validator below).
    finished_at: int | None = None  # Unix seconds; presumably None while running — confirm.
    elapsed_time: float | int | None = None

    @field_validator("created_at", "finished_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # Accept datetimes from the ORM and coerce to Unix seconds.
        return _to_timestamp(value)
def build_workflow_run_model(api_or_ns: Namespace):
    """Build the workflow run model for the API or Namespace.

    Registers ``workflow_run_fields`` under the name "WorkflowRun" and returns
    the resulting flask-restx model.
    """
    return api_or_ns.model("WorkflowRun", workflow_run_fields)
class WorkflowRunForLogResponse(ResponseModel):
    """Compact workflow-run summary embedded in an app-log entry."""

    id: str
    version: str | None = None
    status: str | None = None
    triggered_from: str | None = None
    error: str | None = None
    elapsed_time: float | int | None = None
    total_tokens: int | None = None
    total_steps: int | None = None
    created_at: int | None = None  # Unix seconds.
    finished_at: int | None = None  # Unix seconds.
    exceptions_count: int | None = None

    @field_validator("status", "triggered_from", mode="before")
    @classmethod
    def _normalize_enum(cls, value):
        # ORM rows may carry enums; expose their plain string value.
        return _enum_value(value)

    @field_validator("created_at", "finished_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        return _to_timestamp(value)
class WorkflowAppLogPartialResponse(ResponseModel):
    """One workflow app-log entry, including its run summary and creator info."""

    id: str
    workflow_run: WorkflowRunForLogResponse | None = None
    details: dict | list | str | int | float | bool | None = None
    created_from: str | None = None
    created_by_role: str | None = None
    # Exactly one of the two creator fields is expected to be set — confirm.
    created_by_account: SimpleAccount | None = None
    created_by_end_user: SimpleEndUser | None = None
    created_at: int | None = None  # Unix seconds.

    @field_validator("created_from", "created_by_role", mode="before")
    @classmethod
    def _normalize_enum(cls, value):
        # ORM rows may carry enums; expose their plain string value.
        return _enum_value(value)

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        return _to_timestamp(value)
class WorkflowAppLogPaginationResponse(ResponseModel):
    """Page of workflow app-log entries."""

    page: int
    limit: int
    total: int
    has_more: bool
    data: list[WorkflowAppLogPartialResponse]
# Expose the Pydantic response schemas on the service API namespace so
# @service_api_ns.response(...) can reference them by class name.
register_schema_models(
    service_api_ns,
    WorkflowRunResponse,
    WorkflowRunForLogResponse,
    WorkflowAppLogPartialResponse,
    WorkflowAppLogPaginationResponse,
)
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
    """Render a WorkflowRun row as a JSON-safe dict via WorkflowRunResponse."""
    status = _enum_value(workflow_run.status)
    raw_outputs = workflow_run.outputs_dict

    # Paused runs expose no outputs; otherwise coerce any Mapping into a plain dict.
    outputs: dict = {}
    if status != WorkflowExecutionStatus.PAUSED.value and raw_outputs is not None:
        if isinstance(raw_outputs, dict):
            outputs = raw_outputs
        elif isinstance(raw_outputs, Mapping):
            outputs = dict(raw_outputs)

    payload = {
        "id": workflow_run.id,
        "workflow_id": workflow_run.workflow_id,
        "status": status,
        "inputs": workflow_run.inputs,
        "outputs": outputs,
        "error": workflow_run.error,
        "total_steps": workflow_run.total_steps,
        "total_tokens": workflow_run.total_tokens,
        "created_at": workflow_run.created_at,
        "finished_at": workflow_run.finished_at,
        "elapsed_time": workflow_run.elapsed_time,
    }
    return WorkflowRunResponse.model_validate(payload).model_dump(mode="json")
def _serialize_workflow_log_pagination(pagination) -> dict:
    """Render a workflow app-log pagination object as a JSON-safe dict."""
    response = WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True)
    return response.model_dump(mode="json")
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
@service_api_ns.response(
200,
"Workflow run details retrieved successfully",
service_api_ns.models[WorkflowRunResponse.__name__],
)
def get(self, app_model: App, workflow_run_id: str):
"""Get a workflow task running detail.
@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
)
if not workflow_run:
raise NotFound("Workflow run not found.")
return workflow_run
return _serialize_workflow_run(workflow_run)
@service_api_ns.route("/workflows/run")
@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
@service_api_ns.response(
200,
"Logs retrieved successfully",
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
def get(self, app_model: App):
"""Get workflow app logs.
@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
return _serialize_workflow_log_pagination(workflow_app_log_pagination)

View File

@ -1,7 +1,7 @@
import json
import re
from collections.abc import Generator
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None

View File

@ -57,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[dict[str, Any] | str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -88,7 +88,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[dict[str, Any] | str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -24,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[dict | str, Any, None]:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
yield from cls.convert_stream_full_response(response)
return _generate_full_response()
@ -33,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[dict | str, Any, None]:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
return _generate_simple_response()
@ -52,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
raise NotImplementedError
@classmethod

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -682,15 +682,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
elif invoke_from == InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
elif invoke_from == InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
else:
# not save log for debugging
return
match invoke_from:
case InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
case InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
case InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION:
# not save log for debugging
return
if not workflow_run_id:
return

View File

@ -1,3 +1,5 @@
from typing import Any
from flask import Flask
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel
@ -28,7 +30,7 @@ class FreeHostingQuota(HostingQuota):
class HostingProvider(BaseModel):
    """Configuration of one hosted model provider."""

    enabled: bool = False
    # Merge residue removed: a duplicate untyped `credentials: dict | None = None`
    # (old diff side) preceded this field; only the typed version is kept.
    credentials: dict[str, Any] | None = None
    quota_unit: QuotaUnit | None = None
    # Pydantic copies mutable defaults per instance, so a literal list is safe here.
    quotas: list[HostingQuota] = []

View File

@ -1,6 +1,7 @@
import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from typing import Any
from graphon.model_runtime.entities.llm_entities import (
LLMResult,
@ -226,7 +227,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
# invoke model
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
def handle() -> Generator[dict, None, None]:
def handle() -> Generator[dict[str, Any], None, None]:
for chunk in response:
yield {"result": hexlify(chunk).decode("utf-8")}

View File

@ -95,9 +95,9 @@ class ExtractProcessor:
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
assert upload_file is not None, "upload_file is required"
suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
@ -113,6 +113,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
assert upload_file is not None
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
@ -123,6 +124,7 @@ class ExtractProcessor:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
assert upload_file is not None
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".doc":
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
@ -149,12 +151,14 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
assert upload_file is not None
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
assert upload_file is not None
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)

View File

@ -174,21 +174,25 @@ class FirecrawlApp:
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
    """POST *data* to *url*, retrying with exponential backoff on HTTP 502.

    Returns the first non-502 response, or the last 502 response once retries
    are exhausted.
    """
    response: httpx.Response | None = None
    for attempt_index in range(retries):
        response = httpx.post(url, headers=headers, json=data)
        if response.status_code != 502:
            return response
        # Gateway hiccup: back off exponentially before the next try.
        time.sleep(backoff_factor * (2**attempt_index))
    assert response is not None, "retries must be at least 1"
    return response
def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
    """GET *url*, retrying with exponential backoff on HTTP 502.

    Returns the first non-502 response, or the last 502 response once retries
    are exhausted.
    """
    response: httpx.Response | None = None
    for attempt_index in range(retries):
        response = httpx.get(url, headers=headers)
        if response.status_code != 502:
            return response
        # Gateway hiccup: back off exponentially before the next try.
        time.sleep(backoff_factor * (2**attempt_index))
    assert response is not None, "retries must be at least 1"
    return response
def _handle_error(self, response, action):

View File

@ -254,7 +254,7 @@ def resolve_dify_schema_refs(
return resolver.resolve(schema)
def _remove_metadata_fields(schema: dict) -> dict:
def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]:
"""
Remove metadata fields from schema that shouldn't be included in resolved output

View File

@ -329,7 +329,7 @@ class EnterpriseMetricHandler:
return
include_content = exporter.include_content
attrs: dict = {
attrs: dict[str, Any] = {
"dify.message.id": payload.get("message_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,

View File

@ -9,6 +9,7 @@ from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
from extensions.redis_names import normalize_redis_key_prefix
class _CelerySentinelKwargsDict(TypedDict):
@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict):
password: str | None
class CelerySentinelTransportDict(TypedDict, total=False):
    """Celery broker transport options for Redis Sentinel.

    Merge residue removed: a duplicate old class header without ``total=False``
    preceded this one. ``total=False`` because ``global_keyprefix`` is only
    added when a Redis key prefix is configured.
    """

    master_name: str | None
    sentinel_kwargs: _CelerySentinelKwargsDict
    global_keyprefix: str
class CelerySSLOptionsDict(TypedDict):
@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
    """Get broker transport options (e.g. Redis Sentinel) for Celery connections.

    Merge residue removed: old-side ``return CelerySentinelTransportDict(`` and
    ``return {}`` lines were interleaved with the new assignments, leaving the
    span syntactically invalid; this is the merged (new) behavior.

    When a Redis key prefix is configured, ``global_keyprefix`` is added so the
    Redis transport namespaces its bookkeeping keys alongside application keys.
    """
    transport_options: CelerySentinelTransportDict | dict[str, Any]
    if dify_config.CELERY_USE_SENTINEL:
        transport_options = CelerySentinelTransportDict(
            master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
            sentinel_kwargs=_CelerySentinelKwargsDict(
                socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
                password=dify_config.CELERY_SENTINEL_PASSWORD,
            ),
        )
    else:
        transport_options = {}

    global_keyprefix = get_celery_redis_global_keyprefix()
    if global_keyprefix:
        transport_options["global_keyprefix"] = global_keyprefix
    return transport_options
def get_celery_redis_global_keyprefix() -> str | None:
    """Return the Redis transport prefix for Celery when namespace isolation is enabled."""
    prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
    return f"{prefix}:" if prefix else None
def init_app(app: DifyApp) -> Celery:

View File

@ -3,7 +3,7 @@ import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union
from typing import Any, Union, cast
import redis
from redis import RedisError
@ -18,17 +18,26 @@ from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
from extensions.redis_names import (
normalize_redis_key_prefix,
serialize_redis_name,
serialize_redis_name_arg,
serialize_redis_name_args,
)
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
if TYPE_CHECKING:
from redis.lock import Lock
logger = logging.getLogger(__name__)
# NOTE(review): private aliases of the redis_names helpers — presumably to mark
# module-internal use; confirm before removing.
_normalize_redis_key_prefix = normalize_redis_key_prefix
_serialize_redis_name = serialize_redis_name
_serialize_redis_name_arg = serialize_redis_name_arg
_serialize_redis_name_args = serialize_redis_name_args
class RedisClientWrapper:
"""
A wrapper class for the Redis client that addresses the issue where the global
@ -59,68 +68,148 @@ class RedisClientWrapper:
if self._client is None:
self._client = client
if TYPE_CHECKING:
# Type hints for IDE support and static analysis
# These are not executed at runtime but provide type information
def get(self, name: str | bytes) -> Any: ...
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any: ...
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
def setnx(self, name: str | bytes, value: Any) -> Any: ...
def delete(self, *names: str | bytes) -> Any: ...
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Lock: ...
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def pubsub(self) -> PubSub: ...
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
def _require_client(self) -> redis.Redis | RedisCluster:
    """Return the initialized Redis client or fail fast.

    Merge residue removed: the old ``__getattr__`` delegation (deleted on the
    new diff side) was interleaved with this method's lines.

    Raises:
        RuntimeError: if ``init_app`` has not installed a client yet.
    """
    if self._client is None:
        raise RuntimeError("Redis client is not initialized. Call init_app first.")
    return self._client
def _get_prefix(self) -> str:
    # Prefix applied to every key; an empty string preserves legacy unprefixed names.
    return dify_config.REDIS_KEY_PREFIX
def get(self, name: str | bytes) -> Any:
    """GET the value stored at the prefixed key *name*."""
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return self._require_client().get(key)
def set(
    self,
    name: str | bytes,
    value: Any,
    ex: int | None = None,
    px: int | None = None,
    nx: bool = False,
    xx: bool = False,
    keepttl: bool = False,
    get: bool = False,
    exat: int | None = None,
    pxat: int | None = None,
) -> Any:
    """SET the prefixed key *name* to *value*, forwarding all redis-py options."""
    key = _serialize_redis_name_arg(name, self._get_prefix())
    options: dict[str, Any] = {
        "ex": ex,
        "px": px,
        "nx": nx,
        "xx": xx,
        "keepttl": keepttl,
        "get": get,
        "exat": exat,
        "pxat": pxat,
    }
    return self._require_client().set(key, value, **options)
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any:
    """SETEX: store *value* at the prefixed key *name* with TTL *time*."""
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return self._require_client().setex(key, time, value)
def setnx(self, name: str | bytes, value: Any) -> Any:
    """SETNX: store *value* at the prefixed key *name* only if it does not exist."""
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return self._require_client().setnx(key, value)
def delete(self, *names: str | bytes) -> Any:
    """DEL one or more prefixed keys."""
    prefixed = _serialize_redis_name_args(names, self._get_prefix())
    return self._require_client().delete(*prefixed)
def incr(self, name: str | bytes, amount: int = 1) -> Any:
    """INCRBY the prefixed key *name* by *amount*."""
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return self._require_client().incr(key, amount)
def expire(
    self,
    name: str | bytes,
    time: int | timedelta,
    nx: bool = False,
    xx: bool = False,
    gt: bool = False,
    lt: bool = False,
) -> Any:
    """EXPIRE: set a TTL on the prefixed key *name*, forwarding condition flags."""
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return self._require_client().expire(key, time, nx=nx, xx=xx, gt=gt, lt=lt)
def exists(self, *names: str | bytes) -> Any:
    """EXISTS: count how many of the given (prefixed) keys exist."""
    prefixed = _serialize_redis_name_args(names, self._get_prefix())
    return self._require_client().exists(*prefixed)
def ttl(self, name: str | bytes) -> Any:
    """TTL: return the remaining time-to-live of the prefixed key."""
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return client.ttl(key)
def getdel(self, name: str | bytes) -> Any:
    """GETDEL: return the value at the prefixed key and delete it atomically."""
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return client.getdel(key)
def lock(
    self,
    name: str,
    timeout: float | None = None,
    sleep: float = 0.1,
    blocking: bool = True,
    blocking_timeout: float | None = None,
    thread_local: bool = True,
) -> Any:
    """Create a distributed lock on the prefixed lock *name* (redis-py Lock options)."""
    lock_key = _serialize_redis_name(name, self._get_prefix())
    options = {
        "timeout": timeout,
        "sleep": sleep,
        "blocking": blocking,
        "blocking_timeout": blocking_timeout,
        "thread_local": thread_local,
    }
    return self._require_client().lock(lock_key, **options)
def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any:
    """HSET on the prefixed key; field/value arguments are forwarded untouched."""
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return client.hset(key, *args, **kwargs)
def hgetall(self, name: str | bytes) -> Any:
    """HGETALL: return every field/value pair of the hash at the prefixed key."""
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return client.hgetall(key)
def hdel(self, name: str | bytes, *keys: str | bytes) -> Any:
    """HDEL fields from the hash at the prefixed key.

    Only the hash key itself is prefixed; field names inside the hash
    are not namespaced.
    """
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return client.hdel(key, *keys)
def hlen(self, name: str | bytes) -> Any:
    """HLEN: number of fields in the hash at the prefixed key."""
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return client.hlen(key)
def zadd(
    self,
    name: str | bytes,
    mapping: dict[str | bytes | int | float, float | int | str | bytes],
    nx: bool = False,
    xx: bool = False,
    ch: bool = False,
    incr: bool = False,
    gt: bool = False,
    lt: bool = False,
) -> Any:
    """ZADD members from *mapping* to the sorted set at the prefixed key."""
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    flags = {"nx": nx, "xx": xx, "ch": ch, "incr": incr, "gt": gt, "lt": lt}
    # cast: redis-py's stub for ZADD mappings is narrower than our accepted types.
    return client.zadd(key, cast(Any, mapping), **flags)
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any:
    """ZREMRANGEBYSCORE on the prefixed key.

    Parameter names shadow builtins deliberately to mirror redis-py's signature.
    """
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return client.zremrangebyscore(key, min, max)
def zcard(self, name: str | bytes) -> Any:
    """ZCARD: cardinality of the sorted set at the prefixed key."""
    client = self._require_client()
    key = _serialize_redis_name_arg(name, self._get_prefix())
    return client.zcard(key)
def pubsub(self) -> PubSub:
    """Return a raw PubSub object from the underlying client.

    NOTE(review): channel names are NOT auto-prefixed here; the broadcast
    topic wrappers in this codebase prefix before publishing/subscribing —
    confirm any new direct callers do the same.
    """
    client = self._require_client()
    return client.pubsub()
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any:
    """Return a raw command pipeline from the underlying client.

    NOTE(review): commands queued on the pipeline are not auto-prefixed;
    callers must prefix key names themselves — confirm at call sites.
    """
    client = self._require_client()
    return client.pipeline(transaction=transaction, shard_hint=shard_hint)
def __getattr__(self, item: str) -> Any:
    """Proxy any attribute not wrapped above to the underlying redis client.

    NOTE(review): commands reached through this fallback bypass key
    prefixing; add an explicit wrapper for any command that takes key names.
    """
    client = self._require_client()
    return getattr(client, item)
# Module-level singleton; the wrapped client is attached later via init_app
# (calling commands before that raises RuntimeError, see _require_client).
redis_client: RedisClientWrapper = RedisClientWrapper()

View File

@ -0,0 +1,32 @@
from configs import dify_config
def normalize_redis_key_prefix(prefix: str | None) -> str:
"""Normalize the configured Redis key prefix for consistent runtime use."""
if prefix is None:
return ""
return prefix.strip()
def get_redis_key_prefix() -> str:
    """Read the current Redis key prefix from dify_config, normalized."""
    raw = dify_config.REDIS_KEY_PREFIX
    return normalize_redis_key_prefix(raw)
def serialize_redis_name(name: str, prefix: str | None = None) -> str:
    """Map a logical Redis name to its physical (possibly prefixed) form.

    When *prefix* is None the globally configured prefix is used; an empty
    normalized prefix leaves *name* unchanged.
    """
    if prefix is None:
        effective = get_redis_key_prefix()
    else:
        effective = normalize_redis_key_prefix(prefix)
    return f"{effective}:{name}" if effective else name
def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes:
    """Prefix string Redis names; bytes inputs are returned unchanged."""
    if not isinstance(name, bytes):
        return serialize_redis_name(name, prefix)
    return name
def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]:
    """Apply serialize_redis_name_arg to every element of *names*.

    String names are prefixed; bytes names pass through unchanged. Added
    docstring for consistency with the other helpers in this module.
    """
    return tuple(serialize_redis_name_arg(name, prefix) for name in names)

View File

@ -2,6 +2,7 @@ import logging
import os
from collections.abc import Generator
from pathlib import Path
from typing import Any
import opendal
from dotenv import dotenv_values
@ -19,7 +20,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str
if key.startswith(config_prefix):
kwargs[key[len(config_prefix) :].lower()] = value
file_env_vars: dict = dotenv_values(env_file_path) or {}
file_env_vars: dict[str, Any] = dotenv_values(env_file_path) or {}
for key, value in file_env_vars.items():
if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
kwargs[key[len(config_prefix) :].lower()] = value

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@ -32,12 +33,13 @@ class Topic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.publish(self._topic, payload)
self._client.publish(self._redis_topic, payload)
def as_subscriber(self) -> Subscriber:
return self
@ -46,7 +48,7 @@ class Topic:
return _RedisSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
topic=self._redis_topic,
)

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@ -30,12 +31,13 @@ class ShardedTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr]
def as_subscriber(self) -> Subscriber:
return self
@ -44,7 +46,7 @@ class ShardedTopic:
return _RedisShardedSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
topic=self._redis_topic,
)

View File

@ -6,6 +6,7 @@ import threading
from collections.abc import Iterator
from typing import Self
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
@ -35,7 +36,7 @@ class StreamsTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
self._client = redis_client
self._topic = topic
self._key = f"stream:{topic}"
self._key = serialize_redis_name(f"stream:{topic}")
self._retention_seconds = retention_seconds
self.max_length = 5000

View File

@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock:
timeout=self._ttl_seconds,
thread_local=False,
)
acquired = bool(self._lock.acquire(*args, **kwargs))
lock = self._lock
if lock is None:
raise RuntimeError("Redis lock initialization failed.")
acquired = bool(lock.acquire(*args, **kwargs))
self._acquired = acquired
if acquired:
self._start_heartbeat()

View File

@ -120,10 +120,22 @@ class AppIconUrlField(fields.Raw):
obj = obj["app"]
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(obj.icon)
return build_icon_url(obj.icon_type, obj.icon)
return None
def build_icon_url(icon_type: Any, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
from models.model import IconType
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
class AvatarUrlField(fields.Raw):
def output(self, key, obj, **kwargs):
if obj is None:

View File

@ -1552,7 +1552,7 @@ class PipelineBuiltInTemplate(TypeBase):
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
yaml_content: Mapped[str] = mapped_column(LongText, nullable=False)
copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False)
privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False)
@ -1585,7 +1585,7 @@ class PipelineCustomizedTemplate(TypeBase):
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
yaml_content: Mapped[str] = mapped_column(LongText, nullable=False)
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@ -1658,7 +1658,7 @@ class DocumentPipelineExecutionLog(TypeBase):
datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
datasource_info: Mapped[str] = mapped_column(LongText, nullable=False)
datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
input_data: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False

View File

@ -1007,7 +1007,7 @@ class OAuthProviderApp(TypeBase):
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict)
app_label: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, default_factory=dict)
redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list)
scope: Mapped[str] = mapped_column(
String(255),
@ -2495,7 +2495,7 @@ class TraceAppConfig(TypeBase):
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True)
tracing_config: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
import sqlalchemy as sa
from sqlalchemy import func
@ -22,7 +23,7 @@ class DatasourceOauthParamConfig(TypeBase):
)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
system_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
class DatasourceProvider(TypeBase):
@ -40,7 +41,7 @@ class DatasourceProvider(TypeBase):
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
encrypted_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1)
@ -70,7 +71,7 @@ class DatasourceOauthTenantParamConfig(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict)
client_params: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column(

View File

@ -25,7 +25,7 @@ class DataSourceOauthBinding(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
source_info: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -23,7 +23,7 @@ class ElasticSearchJaVector(ElasticSearchVector):
self,
embeddings: list[list[float]],
metadatas: list[dict[Any, Any]] | None = None,
index_params: dict | None = None,
index_params: dict[str, Any] | None = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):

View File

@ -1,3 +1,4 @@
from typing import Any
from uuid import UUID
from numpy import ndarray
@ -8,5 +9,5 @@ class CollectionORM(DeclarativeBase):
__tablename__: str
id: Mapped[UUID]
text: Mapped[str]
meta: Mapped[dict]
meta: Mapped[dict[str, Any]]
vector: Mapped[ndarray]

View File

@ -67,7 +67,7 @@ class PGVectoRS(BaseVector):
primary_key=True,
)
text: Mapped[str]
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB)
vector: Mapped[ndarray] = mapped_column(VECTOR(dim))
self._table = _Table

View File

@ -455,7 +455,7 @@ class AppDslService:
app.updated_by = account.id
self._session.add(app)
self._session.commit()
self._session.flush()
app_was_created.send(app, account=account)
# save dependencies

View File

@ -87,7 +87,7 @@ class RetrievalModel(BaseModel):
class MetaDataConfig(BaseModel):
doc_type: str
doc_metadata: dict
doc_metadata: dict[str, Any]
class KnowledgeConfig(BaseModel):

View File

@ -1,4 +1,5 @@
import logging
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule
@ -168,7 +169,9 @@ class ModelProviderService:
model_name=model,
)
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
def get_provider_credential(
self, tenant_id: str, provider: str, credential_id: str | None = None
) -> dict[str, Any] | None:
"""
get provider credentials.
@ -180,7 +183,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id)
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]):
"""
validate provider credentials before saving.
@ -192,7 +195,7 @@ class ModelProviderService:
provider_configuration.validate_provider_credentials(credentials)
def create_provider_credential(
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None
) -> None:
"""
Create and save new provider credentials.
@ -210,7 +213,7 @@ class ModelProviderService:
self,
tenant_id: str,
provider: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
) -> None:
@ -254,7 +257,7 @@ class ModelProviderService:
def get_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
) -> dict | None:
) -> dict[str, Any] | None:
"""
Retrieve model-specific credentials.
@ -270,7 +273,9 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
def validate_model_credentials(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any]
):
"""
validate model credentials.
@ -287,7 +292,13 @@ class ModelProviderService:
)
def create_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
self,
tenant_id: str,
provider: str,
model_type: str,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
) -> None:
"""
create and save model credentials.
@ -314,7 +325,7 @@ class ModelProviderService:
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
) -> None:

View File

@ -104,7 +104,7 @@ class RagPipelineService:
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@classmethod
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@ -120,7 +120,7 @@ class RagPipelineService:
return result
@classmethod
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict[str, Any] | None:
"""
Get pipeline template detail.
@ -131,7 +131,7 @@ class RagPipelineService:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
built_in_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id)
if built_in_result is None:
logger.warning(
"pipeline template retrieval returned empty result, template_id: %s, mode: %s",
@ -142,7 +142,7 @@ class RagPipelineService:
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
customized_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id)
return customized_result
@classmethod
@ -297,7 +297,7 @@ class RagPipelineService:
self,
*,
pipeline: Pipeline,
graph: dict,
graph: dict[str, Any],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[VariableBase],
@ -467,7 +467,9 @@ class RagPipelineService:
return default_block_configs
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
def get_default_block_config(
self, node_type: str, filters: dict[str, Any] | None = None
) -> Mapping[str, object] | None:
"""
Get default config of node.
:param node_type: node type
@ -500,7 +502,7 @@ class RagPipelineService:
return default_config
def run_draft_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
self, pipeline: Pipeline, node_id: str, user_inputs: dict[str, Any], account: Account
) -> WorkflowNodeExecutionModel | None:
"""
Run draft workflow node
@ -582,7 +584,7 @@ class RagPipelineService:
self,
pipeline: Pipeline,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
account: Account,
datasource_type: str,
is_published: bool,
@ -749,7 +751,7 @@ class RagPipelineService:
self,
pipeline: Pipeline,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
account: Account,
datasource_type: str,
is_published: bool,
@ -979,7 +981,7 @@ class RagPipelineService:
return workflow_node_execution
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
) -> Workflow | None:
"""
Update workflow attributes
@ -1099,7 +1101,9 @@ class RagPipelineService:
]
return datasource_provider_variables
def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
def get_rag_pipeline_paginate_workflow_runs(
self, pipeline: Pipeline, args: dict[str, Any]
) -> InfiniteScrollPagination:
"""
Get debug workflow run list
Only return triggered_from == debugging
@ -1169,7 +1173,7 @@ class RagPipelineService:
return list(node_executions)
@classmethod
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]):
"""
Publish customized pipeline template
"""
@ -1259,7 +1263,7 @@ class RagPipelineService:
)
return node_exec
def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account):
def set_datasource_variables(self, pipeline: Pipeline, args: dict[str, Any], current_user: Account):
"""
Set datasource variables
"""
@ -1346,7 +1350,7 @@ class RagPipelineService:
)
return workflow_node_execution_db_model
def get_recommended_plugins(self, type: str) -> dict:
def get_recommended_plugins(self, type: str) -> dict[str, Any]:
# Query active recommended plugins
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":

View File

@ -241,8 +241,8 @@ class WorkflowService:
self,
*,
app_model: App,
graph: dict,
features: dict,
graph: dict[str, Any],
features: dict[str, Any],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[VariableBase],
@ -576,7 +576,7 @@ class WorkflowService:
except Exception as e:
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None:
"""
Validate load balancing credentials for a workflow node.
@ -1214,7 +1214,7 @@ class WorkflowService:
return variable_pool
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
self, node_data: dict[str, Any], tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
"""
Run free workflow node
@ -1361,7 +1361,7 @@ class WorkflowService:
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
def convert_to_workflow(self, app_model: App, account: Account, args: dict[str, Any]) -> App:
"""
Basic mode of chatbot app(expert mode) to workflow
Completion App to Workflow App
@ -1421,7 +1421,7 @@ class WorkflowService:
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
self._validate_human_input_node_data(node_data)
def validate_features_structure(self, app_model: App, features: dict):
def validate_features_structure(self, app_model: App, features: dict[str, Any]):
match app_model.mode:
case AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
@ -1434,7 +1434,7 @@ class WorkflowService:
case _:
raise ValueError(f"Invalid app mode: {app_model.mode}")
def _validate_human_input_node_data(self, node_data: dict) -> None:
def _validate_human_input_node_data(self, node_data: dict[str, Any]) -> None:
"""
Validate HumanInput node data format.
@ -1452,7 +1452,7 @@ class WorkflowService:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
) -> Workflow | None:
"""
Update workflow attributes

View File

@ -33,6 +33,7 @@ REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0
REDIS_KEY_PREFIX=
# PostgreSQL database configuration
DB_USERNAME=postgres

View File

@ -1,11 +0,0 @@
import pytest
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
CODE_LANGUAGE = "unsupported_language"
def test_unsupported_with_code_template():
with pytest.raises(CodeExecutionError) as e:
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"

View File

@ -1,36 +0,0 @@
from textwrap import dedent
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
CODE_LANGUAGE = CodeLanguage.PYTHON3
def test_python3_plain():
code = 'print("Hello World")'
result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
assert result == "Hello World\n"
def test_python3_json():
code = dedent("""
import json
print(json.dumps({'Hello': 'World'}))
""")
result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
assert result == '{"Hello": "World"}\n'
def test_python3_with_code_template():
result = CodeExecutor.execute_workflow_code_template(
language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"}
)
assert result == {"result": "HelloWorld"}
def test_python3_get_runner_script():
runner_script = Python3TemplateTransformer.get_runner_script()
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Python3TemplateTransformer._result_tag) == 2

View File

@ -432,7 +432,7 @@ class TestWorkflowAppLogEndpoints:
monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker)
def fake_get_paginate(self, **_kwargs):
return {"items": [], "total": 0}
return {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
monkeypatch.setattr(
workflow_app_log_module.WorkflowAppService,
@ -443,7 +443,7 @@ class TestWorkflowAppLogEndpoints:
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
class TestWorkflowDraftVariableEndpoints:
@ -608,7 +608,8 @@ class TestWorkflowTriggerEndpoints:
with app.test_request_context("/?node_id=node-1"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is trigger
assert isinstance(result, dict)
assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys())
class TestWrapsEndpoints:

View File

@ -96,6 +96,56 @@ class TestAppImportApi:
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
    """A successful DSL import must commit the session and never roll back."""
    api = app_import_module.AppImportApi()
    method = _unwrap(api.post)
    _install_features(monkeypatch, enabled=False)
    # Force the service layer to report success without doing real work.
    monkeypatch.setattr(
        app_import_module.AppDslService,
        "import_app",
        lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
    )
    monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
    # Stand-in Session records commit/rollback calls; __enter__/__exit__ make
    # it usable as a context manager inside the endpoint.
    fake_session = MagicMock()
    fake_session.__enter__.return_value = fake_session
    fake_session.__exit__.return_value = None
    monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
    with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
        response, status = method()
    fake_session.commit.assert_called_once_with()
    fake_session.rollback.assert_not_called()
    assert status == 200
    assert response["status"] == ImportStatus.COMPLETED
def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
    """A failed DSL import must roll the session back and never commit."""
    api = app_import_module.AppImportApi()
    method = _unwrap(api.post)
    _install_features(monkeypatch, enabled=False)
    # Force the service layer to report failure without doing real work.
    monkeypatch.setattr(
        app_import_module.AppDslService,
        "import_app",
        lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
    )
    monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
    # Stand-in Session records commit/rollback calls; __enter__/__exit__ make
    # it usable as a context manager inside the endpoint.
    fake_session = MagicMock()
    fake_session.__enter__.return_value = fake_session
    fake_session.__exit__.return_value = None
    monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
    with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
        response, status = method()
    fake_session.rollback.assert_called_once_with()
    fake_session.commit.assert_not_called()
    assert status == 400
    assert response["status"] == ImportStatus.FAILED
class TestAppImportConfirmApi:
@pytest.fixture

View File

@ -0,0 +1,73 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console.app.conversation import _get_conversation
from models.enums import ConversationFromSource
from models.model import AppMode, Conversation
from tests.test_containers_integration_tests.controllers.console.helpers import (
create_console_account_and_tenant,
create_console_app,
)
def test_get_conversation_mark_read_keeps_updated_at_unchanged(
    db_session_with_containers: Session,
):
    """_get_conversation marks the conversation read without bumping updated_at."""
    account, tenant = create_console_account_and_tenant(db_session_with_containers)
    app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
    # Fixed sentinel timestamp; the final assertion checks it survives unchanged.
    original_updated_at = datetime(2026, 2, 8, 0, 0, 0)
    conversation = Conversation(
        app_id=app.id,
        name="read timestamp test",
        inputs={},
        status="normal",
        mode=AppMode.CHAT,
        from_source=ConversationFromSource.CONSOLE,
        from_account_id=account.id,
        updated_at=original_updated_at,
    )
    db_session_with_containers.add(conversation)
    db_session_with_containers.commit()
    # Freeze "now" so read_at is deterministic and distinct from updated_at.
    read_at = datetime(2026, 2, 9, 0, 0, 0)
    with (
        patch(
            "controllers.console.app.conversation.current_account_with_tenant",
            return_value=(account, tenant.id),
            autospec=True,
        ),
        patch(
            "controllers.console.app.conversation.naive_utc_now",
            return_value=read_at,
            autospec=True,
        ),
    ):
        loaded = _get_conversation(app, conversation.id)
    # Re-read from the DB so the assertions see persisted state, not stale ORM state.
    db_session_with_containers.refresh(conversation)
    assert loaded.id == conversation.id
    assert conversation.read_at == read_at
    assert conversation.read_account_id == account.id
    assert conversation.updated_at == original_updated_at
def test_get_conversation_raises_not_found_for_missing_conversation(
    db_session_with_containers: Session,
):
    """Looking up an unknown conversation id raises NotFound."""
    account, tenant = create_console_account_and_tenant(db_session_with_containers)
    app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
    patched_user = patch(
        "controllers.console.app.conversation.current_account_with_tenant",
        return_value=(account, tenant.id),
        autospec=True,
    )
    with patched_user, pytest.raises(NotFound):
        _get_conversation(app, "00000000-0000-0000-0000-000000000000")

View File

@ -0,0 +1,110 @@
"""
Testcontainers integration tests for Service API Site controller.
"""
from __future__ import annotations
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.service_api.app.site import AppSiteApi
from models.account import Tenant, TenantStatus
from models.model import App, AppMode, Site
@pytest.fixture
def app(flask_app_with_containers) -> Flask:
    """Expose the container-backed Flask app fixture under the name ``app``."""
    return flask_app_with_containers
def _unwrap(method):
fn = method
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return fn
def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
    """Persist and return a tenant named for this test module with *status*."""
    record = Tenant(name="service-api-site-tenant", status=status)
    db_session.add(record)
    db_session.commit()
    return record
def _create_app(db_session: Session, tenant_id: str) -> App:
    """Persist and return a chat-mode app for *tenant_id* with both the web
    site and the service API enabled."""
    attrs = {
        "tenant_id": tenant_id,
        "mode": AppMode.CHAT,
        "name": "service-api-site-app",
        "enable_site": True,
        "enable_api": True,
        "status": "normal",
    }
    record = App(**attrs)
    db_session.add(record)
    db_session.commit()
    return record
def _create_site(db_session: Session, app_id: str) -> Site:
    """Persist and return a Site row bound to *app_id* with fixed test values."""
    attrs = {
        "app_id": app_id,
        "title": "Service API Site",
        "icon_type": "emoji",
        "icon": "robot",
        "icon_background": "#ffffff",
        "description": "Service API test site",
        "default_language": "en-US",
        "prompt_public": True,
        "show_workflow_steps": True,
        "customize_token_strategy": "not_allow",
        "use_icon_as_answer_icon": False,
        "chat_color_theme": "light",
        "chat_color_theme_inverted": False,
    }
    record = Site(**attrs)
    db_session.add(record)
    db_session.commit()
    return record
class TestAppSiteApi:
    """Integration tests for the Service API /site endpoint handler."""

    def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None:
        """An app with a persisted Site returns the site's display fields."""
        tenant = _create_tenant(db_session_with_containers)
        app_model = _create_app(db_session_with_containers, tenant.id)
        _create_site(db_session_with_containers, app_model.id)
        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
            api = AppSiteApi()
            # _unwrap bypasses the auth decorators; app_model is injected directly.
            response = _unwrap(api.get)(api, app_model=app_model)
        assert response["title"] == "Service API Site"
        assert response["icon"] == "robot"
        assert response["description"] == "Service API test site"

    def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None:
        """An app without a Site row is rejected with Forbidden."""
        tenant = _create_tenant(db_session_with_containers)
        app_model = _create_app(db_session_with_containers, tenant.id)
        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
            api = AppSiteApi()
            with pytest.raises(Forbidden):
                _unwrap(api.get)(api, app_model=app_model)

    def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None:
        """A site under an archived tenant is rejected with Forbidden."""
        tenant = _create_tenant(db_session_with_containers)
        app_model = _create_app(db_session_with_containers, tenant.id)
        _create_site(db_session_with_containers, app_model.id)
        # Archive the tenant after creation, then reload the app so its
        # tenant relationship reflects the archived status.
        archived_tenant = db_session_with_containers.get(Tenant, tenant.id)
        assert archived_tenant is not None
        archived_tenant.status = TenantStatus.ARCHIVE
        db_session_with_containers.commit()
        app_model = db_session_with_containers.get(App, app_model.id)
        assert app_model is not None
        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
            api = AppSiteApi()
            with pytest.raises(Forbidden):
                _unwrap(api.get)(api, app_model=app_model)

View File

@ -1,28 +1,48 @@
"""Unit tests for controllers.web.site endpoints."""
"""Testcontainers integration tests for controllers.web.site endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.web.site import AppSiteApi, AppSiteInfo
from models import Tenant, TenantStatus
from models.model import App, AppMode, CustomizeTokenStrategy, Site
def _tenant(*, status: str = "normal") -> SimpleNamespace:
return SimpleNamespace(
id="tenant-1",
status=status,
plan="basic",
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
@pytest.fixture
def app(flask_app_with_containers) -> Flask:
return flask_app_with_containers
def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
tenant = Tenant(name="test-tenant", status=status)
db_session.add(tenant)
db_session.commit()
return tenant
def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True) -> App:
app_model = App(
tenant_id=tenant_id,
mode=AppMode.CHAT,
name="test-app",
enable_site=enable_site,
enable_api=True,
)
db_session.add(app_model)
db_session.commit()
return app_model
def _site() -> SimpleNamespace:
return SimpleNamespace(
def _create_site(db_session: Session, app_id: str) -> Site:
site = Site(
app_id=app_id,
title="Site",
icon_type="emoji",
icon="robot",
@ -31,77 +51,64 @@ def _site() -> SimpleNamespace:
default_language="en",
chat_color_theme="light",
chat_color_theme_inverted=False,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
code=f"code-{app_id[-6:]}",
prompt_public=False,
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
db_session.add(site)
db_session.commit()
return site
# ---------------------------------------------------------------------------
# AppSiteApi
# ---------------------------------------------------------------------------
class TestAppSiteApi:
@patch("controllers.web.site.FeatureService.get_features")
@patch("controllers.web.site.db")
def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
site_obj = _site()
mock_db.session.scalar.return_value = site_obj
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
_create_site(db_session_with_containers, app_model.id)
end_user = SimpleNamespace(id="eu-1")
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
with app.test_request_context("/site"):
result = AppSiteApi().get(app_model, end_user)
# marshal_with serializes AppSiteInfo to a dict
assert result["app_id"] == "app-1"
assert result["app_id"] == app_model.id
assert result["plan"] == "basic"
assert result["enable_site"] is True
@patch("controllers.web.site.db")
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
def test_missing_site_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_db.session.scalar.return_value = None
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
end_user = SimpleNamespace(id="eu-1")
with app.test_request_context("/site"):
with pytest.raises(Forbidden):
AppSiteApi().get(app_model, end_user)
@patch("controllers.web.site.db")
def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
@patch("controllers.web.site.FeatureService.get_features")
def test_archived_tenant_raises_forbidden(
self, mock_features, app: Flask, db_session_with_containers: Session
) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
from models.account import TenantStatus
mock_db.session.scalar.return_value = _site()
tenant = SimpleNamespace(
id="tenant-1",
status=TenantStatus.ARCHIVE,
plan="basic",
custom_config_dict={},
)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE)
app_model = _create_app(db_session_with_containers, tenant.id)
_create_site(db_session_with_containers, app_model.id)
end_user = SimpleNamespace(id="eu-1")
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
with app.test_request_context("/site"):
with pytest.raises(Forbidden):
AppSiteApi().get(app_model, end_user)
# ---------------------------------------------------------------------------
# AppSiteInfo
# ---------------------------------------------------------------------------
class TestAppSiteInfo:
def test_basic_fields(self) -> None:
tenant = _tenant()
site_obj = _site()
tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={})
site_obj = SimpleNamespace()
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
assert info.app_id == "app-1"
@ -118,7 +125,7 @@ class TestAppSiteInfo:
plan="pro",
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
)
site_obj = _site()
site_obj = SimpleNamespace()
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
assert info.can_replace_logo is True

View File

@ -0,0 +1,613 @@
"""Testcontainers integration tests for DatasetService permission and lifecycle SQL paths."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetAutoDisableLog,
DatasetCollectionBinding,
DatasetPermission,
DatasetPermissionEnum,
)
from models.enums import DataSourceType
from services.dataset_service import DatasetCollectionBindingService, DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError
class DatasetPermissionIntegrationFactory:
    """Builders that persist accounts, tenants, datasets, and related rows
    needed by the DatasetService/DatasetPermissionService integration tests."""

    @staticmethod
    def create_account_with_tenant(
        db_session_with_containers: Session,
        role: TenantAccountRole = TenantAccountRole.OWNER,
    ) -> tuple[Account, Tenant]:
        """Create an account plus a fresh tenant and join them with *role*."""
        account = Account(
            email=f"{uuid4()}@example.com",
            name=f"user-{uuid4()}",
            interface_language="en-US",
            status="active",
        )
        tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
        db_session_with_containers.add_all([account, tenant])
        # Flush so both rows get ids before the join row references them.
        db_session_with_containers.flush()
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=role,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()
        # Mirror what login normally sets up on the in-memory account object.
        account.role = role
        account._current_tenant = tenant
        return account, tenant

    @staticmethod
    def create_account_in_tenant(
        db_session_with_containers: Session,
        tenant: Tenant,
        role: TenantAccountRole = TenantAccountRole.EDITOR,
    ) -> Account:
        """Create an additional account joined to an existing *tenant*."""
        account = Account(
            email=f"{uuid4()}@example.com",
            name=f"user-{uuid4()}",
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.flush()
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=role,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()
        account.role = role
        account._current_tenant = tenant
        return account

    @staticmethod
    def create_dataset(
        db_session_with_containers: Session,
        *,
        tenant_id: str,
        created_by: str,
        name: str | None = None,
        permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
        indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY,
        enable_api: bool = True,
    ) -> Dataset:
        """Persist a dataset owned by *created_by* under *tenant_id*."""
        dataset = Dataset(
            tenant_id=tenant_id,
            name=name or f"dataset-{uuid4()}",
            description="desc",
            data_source_type=DataSourceType.UPLOAD_FILE,
            indexing_technique=indexing_technique,
            created_by=created_by,
            provider="vendor",
            permission=permission,
            retrieval_model={"top_k": 2},
        )
        # Set separately so it applies regardless of the model constructor.
        dataset.enable_api = enable_api
        db_session_with_containers.add(dataset)
        db_session_with_containers.commit()
        return dataset

    @staticmethod
    def create_dataset_permission(
        db_session_with_containers: Session,
        *,
        dataset_id: str,
        tenant_id: str,
        account_id: str,
    ) -> DatasetPermission:
        """Grant *account_id* explicit access to *dataset_id* (partial-team)."""
        permission = DatasetPermission(
            dataset_id=dataset_id,
            tenant_id=tenant_id,
            account_id=account_id,
            has_permission=True,
        )
        db_session_with_containers.add(permission)
        db_session_with_containers.commit()
        return permission

    @staticmethod
    def create_app_dataset_join(
        db_session_with_containers: Session,
        *,
        dataset_id: str,
    ) -> AppDatasetJoin:
        """Link *dataset_id* to a synthetic app so it counts as "in use"."""
        join = AppDatasetJoin(
            app_id=str(uuid4()),
            dataset_id=dataset_id,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()
        return join

    @staticmethod
    def create_collection_binding(
        db_session_with_containers: Session,
        *,
        provider_name: str,
        model_name: str,
        collection_type: str = "dataset",
    ) -> DatasetCollectionBinding:
        """Persist a vector-collection binding for (provider, model)."""
        binding = DatasetCollectionBinding(
            provider_name=provider_name,
            model_name=model_name,
            collection_name=f"collection_{uuid4().hex}",
            type=collection_type,
        )
        db_session_with_containers.add(binding)
        db_session_with_containers.commit()
        return binding

    @staticmethod
    def create_auto_disable_log(
        db_session_with_containers: Session,
        *,
        tenant_id: str,
        dataset_id: str,
        document_id: str,
    ) -> DatasetAutoDisableLog:
        """Persist an auto-disable log entry for one document of a dataset."""
        log = DatasetAutoDisableLog(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            document_id=document_id,
        )
        db_session_with_containers.add(log)
        db_session_with_containers.commit()
        return log
class TestDatasetServicePermissionsAndLifecycle:
    """SQL-backed tests for DatasetService deletion, use-check, permission
    checks, API-status updates, and auto-disable log retrieval."""

    def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: Session):
        """Deleting a non-existent dataset id returns False (no exception)."""
        owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        result = DatasetService.delete_dataset(str(uuid4()), user=owner)
        assert result is False

    def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_with_containers: Session):
        """A permitted delete removes the row and fires dataset_was_deleted."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
        )
        with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal:
            result = DatasetService.delete_dataset(dataset.id, user=owner)
        assert result is True
        assert db_session_with_containers.get(Dataset, dataset.id) is None
        send_deleted_signal.assert_called_once_with(dataset)

    def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_containers: Session):
        """dataset_use_check is True when an AppDatasetJoin references it."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
        )
        DatasetPermissionIntegrationFactory.create_app_dataset_join(
            db_session_with_containers,
            dataset_id=dataset.id,
        )
        assert DatasetService.dataset_use_check(dataset.id) is True

    def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session):
        """dataset_use_check is False when no app references the dataset."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
        )
        assert DatasetService.dataset_use_check(dataset.id) is False

    def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session):
        """A user from a different tenant is denied regardless of mode."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        outsider, _other_tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(
            db_session_with_containers
        )
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
        )
        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_permission(dataset, outsider)

    def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(
        self, db_session_with_containers: Session
    ):
        """ONLY_ME datasets deny same-tenant members who are not the creator."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.ONLY_ME,
        )
        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_permission(dataset, member)

    def test_check_dataset_permission_rejects_partial_team_user_without_binding(
        self, db_session_with_containers: Session
    ):
        """PARTIAL_TEAM denies a member with no DatasetPermission row."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_permission(dataset, member)

    def test_check_dataset_permission_allows_partial_team_creator(self, db_session_with_containers: Session):
        """PARTIAL_TEAM always allows the dataset's creator."""
        creator, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(
            db_session_with_containers,
            role=TenantAccountRole.EDITOR,
        )
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=creator.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        # Must not raise.
        DatasetService.check_dataset_permission(dataset, creator)

    def test_check_dataset_permission_allows_partial_team_member_with_binding(
        self, db_session_with_containers: Session
    ):
        """PARTIAL_TEAM allows a member who has a DatasetPermission row."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        DatasetPermissionIntegrationFactory.create_dataset_permission(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant.id,
            account_id=member.id,
        )
        # Must not raise.
        DatasetService.check_dataset_permission(dataset, member)

    def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(
        self, db_session_with_containers: Session
    ):
        """Operator check denies ONLY_ME datasets to non-creators."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
            db_session_with_containers,
            tenant,
            role=TenantAccountRole.EDITOR,
        )
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.ONLY_ME,
        )
        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)

    def test_check_dataset_operator_permission_rejects_partial_team_without_binding(
        self, db_session_with_containers: Session
    ):
        """Operator check denies PARTIAL_TEAM without a permission row."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
            db_session_with_containers,
            tenant,
            role=TenantAccountRole.EDITOR,
        )
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)

    def test_check_dataset_operator_permission_allows_partial_team_with_binding(
        self, db_session_with_containers: Session
    ):
        """Operator check allows PARTIAL_TEAM when a permission row exists."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
            db_session_with_containers,
            tenant,
            role=TenantAccountRole.EDITOR,
        )
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        DatasetPermissionIntegrationFactory.create_dataset_permission(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant.id,
            account_id=operator.id,
        )
        # Must not raise.
        DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)

    def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers):
        """Updating API status of an unknown dataset raises NotFound."""
        with flask_app_with_containers.app_context():
            with pytest.raises(NotFound, match="Dataset not found"):
                DatasetService.update_dataset_api_status(str(uuid4()), True)

    def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session):
        """Updating API status fails when the current user has no id."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            enable_api=False,
        )
        with patch("services.dataset_service.current_user", SimpleNamespace(id=None)):
            with pytest.raises(ValueError, match="Current user or current user id not found"):
                DatasetService.update_dataset_api_status(dataset.id, True)

    def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session):
        """A successful update sets enable_api, updated_by, and updated_at."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            enable_api=False,
        )
        # Freeze the clock so updated_at can be compared exactly.
        now = datetime(2026, 4, 14, 18, 0, 0)
        with (
            patch("services.dataset_service.current_user", owner),
            patch("services.dataset_service.naive_utc_now", return_value=now),
        ):
            DatasetService.update_dataset_api_status(dataset.id, True)
        db_session_with_containers.refresh(dataset)
        assert dataset.enable_api is True
        assert dataset.updated_by == owner.id
        assert dataset.updated_at == now

    def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(
        self, db_session_with_containers: Session
    ):
        """With billing disabled the log query short-circuits to empty."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        features = SimpleNamespace(
            billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="professional"))
        )
        with (
            patch("services.dataset_service.current_user", owner),
            patch("services.dataset_service.FeatureService.get_features", return_value=features),
        ):
            result = DatasetService.get_dataset_auto_disable_logs(str(uuid4()))
        assert result == {"document_ids": [], "count": 0}

    def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_session_with_containers: Session):
        """With billing enabled the persisted log rows are returned."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
        )
        DatasetPermissionIntegrationFactory.create_auto_disable_log(
            db_session_with_containers,
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            document_id=str(uuid4()),
        )
        DatasetPermissionIntegrationFactory.create_auto_disable_log(
            db_session_with_containers,
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            document_id=str(uuid4()),
        )
        features = SimpleNamespace(
            billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan="professional"))
        )
        with (
            patch("services.dataset_service.current_user", owner),
            patch("services.dataset_service.FeatureService.get_features", return_value=features),
        ):
            result = DatasetService.get_dataset_auto_disable_logs(dataset.id)
        assert result["count"] == 2
        assert len(result["document_ids"]) == 2
class TestDatasetCollectionBindingServiceIntegration:
    """SQL-backed tests for DatasetCollectionBindingService lookup/create."""

    def test_get_dataset_collection_binding_returns_existing_binding(self, db_session_with_containers: Session):
        """An existing (provider, model) binding is returned unchanged."""
        binding = DatasetPermissionIntegrationFactory.create_collection_binding(
            db_session_with_containers,
            provider_name="provider",
            model_name="model",
        )
        result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")
        assert result.id == binding.id

    def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session):
        """A missing binding is created on demand with type 'dataset'."""
        result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model")
        persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id)
        assert persisted is not None
        assert persisted.provider_name == "provider"
        assert persisted.model_name == "missing-model"
        assert persisted.type == "dataset"
        # A non-empty collection name is generated automatically.
        assert persisted.collection_name

    def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers):
        """Lookup by an unknown binding id raises ValueError."""
        with flask_app_with_containers.app_context():
            with pytest.raises(ValueError, match="Dataset collection binding not found"):
                DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4()))

    def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session):
        """Lookup by id returns the persisted binding."""
        binding = DatasetPermissionIntegrationFactory.create_collection_binding(
            db_session_with_containers,
            provider_name="provider",
            model_name="model",
        )
        result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id)
        assert result.id == binding.id
class TestDatasetPermissionServiceIntegration:
    """SQL-backed tests for DatasetPermissionService member-list management
    and in-memory tests for its permission-change guard rails."""

    def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session_with_containers: Session):
        """All account ids with a DatasetPermission row are returned."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        DatasetPermissionIntegrationFactory.create_dataset_permission(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant.id,
            account_id=member_a.id,
        )
        DatasetPermissionIntegrationFactory.create_dataset_permission(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant.id,
            account_id=member_b.id,
        )
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        # Order is not guaranteed, so compare as sets.
        assert set(result) == {member_a.id, member_b.id}

    def test_update_partial_member_list_replaces_permissions_and_commits(self, db_session_with_containers: Session):
        """The full member list is replaced: stale rows gone, new rows present."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        stale_member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        DatasetPermissionIntegrationFactory.create_dataset_permission(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant.id,
            account_id=stale_member.id,
        )
        DatasetPermissionService.update_partial_member_list(
            tenant.id,
            dataset.id,
            [{"user_id": member_a.id}, {"user_id": member_b.id}],
        )
        permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all()
        assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id}

    def test_check_permission_requires_dataset_editor(self):
        """A user without dataset-editor rights cannot change permissions."""
        user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False)
        dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM)
        with pytest.raises(NoPermissionError, match="does not have permission"):
            DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ALL_TEAM, [])

    def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self):
        """A dataset operator cannot switch the dataset's permission mode."""
        user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM)
        with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"):
            DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, [])

    def test_check_permission_requires_partial_member_list_for_partial_members_mode(self):
        """PARTIAL_TEAM mode requires a non-empty member list."""
        user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
        with pytest.raises(ValueError, match="Partial member list is required"):
            DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, [])

    def test_check_permission_rejects_dataset_operator_member_list_changes(self):
        """A dataset operator may not alter the existing member list."""
        user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
        with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]):
            with pytest.raises(ValueError, match="cannot change the dataset permissions"):
                DatasetPermissionService.check_permission(
                    user,
                    dataset,
                    DatasetPermissionEnum.PARTIAL_TEAM,
                    [{"user_id": "user-2"}],
                )

    def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self):
        """An unchanged member list passes the operator guard."""
        user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
        with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]):
            # Must not raise.
            DatasetPermissionService.check_permission(
                user,
                dataset,
                DatasetPermissionEnum.PARTIAL_TEAM,
                [{"user_id": "user-1"}],
            )

    def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session):
        """clear_partial_member_list removes every permission row."""
        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
        dataset = DatasetPermissionIntegrationFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        DatasetPermissionIntegrationFactory.create_dataset_permission(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=tenant.id,
            account_id=member.id,
        )
        DatasetPermissionService.clear_partial_member_list(dataset.id)
        remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all()
        assert remaining == []

View File

@ -0,0 +1,387 @@
"""Testcontainers integration tests for schedule service SQL-backed behavior."""
from datetime import datetime
from types import SimpleNamespace
from uuid import uuid4
import pytest
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate
from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
from events.event_handlers.sync_workflow_schedule_when_app_published import sync_schedule_from_workflow
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.trigger import WorkflowSchedulePlan
from services.errors.account import AccountNotFoundError
from services.trigger.schedule_service import ScheduleService
class ScheduleServiceIntegrationFactory:
    """Builders for accounts/tenants and schedule-plan rows used by the
    ScheduleService integration tests."""

    @staticmethod
    def create_account_with_tenant(
        db_session_with_containers: Session,
        role: TenantAccountRole = TenantAccountRole.OWNER,
    ) -> tuple[Account, Tenant]:
        """Create an account plus a fresh tenant and join them with *role*."""
        account = Account(
            email=f"{uuid4()}@example.com",
            name=f"user-{uuid4()}",
            interface_language="en-US",
            status="active",
        )
        tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
        db_session_with_containers.add_all([account, tenant])
        # Flush so both rows have ids before the join row references them.
        db_session_with_containers.flush()
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=role,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()
        # Mirror the tenant binding that login normally establishes.
        account.current_tenant = tenant
        return account, tenant

    @staticmethod
    def create_schedule_plan(
        db_session_with_containers: Session,
        *,
        tenant_id: str,
        app_id: str | None = None,
        node_id: str = "start",
        cron_expression: str = "30 10 * * *",
        timezone: str = "UTC",
        next_run_at: datetime | None = None,
    ) -> WorkflowSchedulePlan:
        """Persist a WorkflowSchedulePlan row with sensible cron defaults."""
        schedule = WorkflowSchedulePlan(
            tenant_id=tenant_id,
            app_id=app_id or str(uuid4()),
            node_id=node_id,
            cron_expression=cron_expression,
            timezone=timezone,
            next_run_at=next_run_at,
        )
        db_session_with_containers.add(schedule)
        db_session_with_containers.commit()
        return schedule
def _cron_workflow(
*,
node_id: str = "start",
cron_expression: str = "30 10 * * *",
timezone: str = "UTC",
):
return SimpleNamespace(
graph_dict={
"nodes": [
{
"id": node_id,
"data": {
"type": "trigger-schedule",
"mode": "cron",
"cron_expression": cron_expression,
"timezone": timezone,
},
}
]
}
)
def _no_schedule_workflow():
return SimpleNamespace(
graph_dict={
"nodes": [
{
"id": "node-1",
"data": {"type": "llm"},
}
]
}
)
class TestScheduleServiceIntegration:
    """SQL-backed tests for ScheduleService CRUD and tenant-owner resolution."""

    def test_create_schedule_persists_schedule(self, db_session_with_containers: Session):
        account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        expected_next_run = datetime(2026, 1, 1, 10, 30, 0)
        config = ScheduleConfig(
            node_id="start",
            cron_expression="30 10 * * *",
            timezone="UTC",
        )
        with pytest.MonkeyPatch.context() as mp:
            # Pin the cron computation so the persisted next_run_at is deterministic.
            mp.setattr(
                "services.trigger.schedule_service.calculate_next_run_at",
                lambda *_args, **_kwargs: expected_next_run,
            )
            created = ScheduleService.create_schedule(
                session=db_session_with_containers,
                tenant_id=tenant.id,
                app_id=str(uuid4()),
                config=config,
            )
        stored = db_session_with_containers.get(WorkflowSchedulePlan, created.id)
        assert stored is not None
        assert stored.tenant_id == tenant.id
        assert stored.node_id == "start"
        assert stored.cron_expression == "30 10 * * *"
        assert stored.timezone == "UTC"
        assert stored.next_run_at == expected_next_run

    def test_update_schedule_updates_fields_and_recomputes_next_run(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        plan = ScheduleServiceIntegrationFactory.create_schedule_plan(
            db_session_with_containers,
            tenant_id=tenant.id,
            cron_expression="30 10 * * *",
            timezone="UTC",
        )
        expected_next_run = datetime(2026, 1, 2, 12, 0, 0)
        with pytest.MonkeyPatch.context() as mp:
            mp.setattr(
                "services.trigger.schedule_service.calculate_next_run_at",
                lambda *_args, **_kwargs: expected_next_run,
            )
            updated = ScheduleService.update_schedule(
                session=db_session_with_containers,
                schedule_id=plan.id,
                updates=SchedulePlanUpdate(
                    cron_expression="0 12 * * *",
                    timezone="America/New_York",
                ),
            )
        db_session_with_containers.refresh(updated)
        assert updated.cron_expression == "0 12 * * *"
        assert updated.timezone == "America/New_York"
        assert updated.next_run_at == expected_next_run

    def test_update_schedule_updates_only_node_id_without_recomputing_time(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        initial_next_run = datetime(2026, 1, 1, 10, 0, 0)
        plan = ScheduleServiceIntegrationFactory.create_schedule_plan(
            db_session_with_containers,
            tenant_id=tenant.id,
            next_run_at=initial_next_run,
        )
        with pytest.MonkeyPatch.context() as mp:
            recompute_calls: list[tuple] = []

            def _record_call(*args, **kwargs):
                recompute_calls.append((args, kwargs))
                return datetime(2026, 1, 9, 10, 0, 0)

            mp.setattr("services.trigger.schedule_service.calculate_next_run_at", _record_call)
            updated = ScheduleService.update_schedule(
                session=db_session_with_containers,
                schedule_id=plan.id,
                updates=SchedulePlanUpdate(node_id="node-new"),
            )
        db_session_with_containers.refresh(updated)
        assert updated.node_id == "node-new"
        # A node-only update must leave the schedule time untouched.
        assert updated.next_run_at == initial_next_run
        assert recompute_calls == []

    def test_update_schedule_not_found_raises(self, db_session_with_containers: Session):
        with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
            ScheduleService.update_schedule(
                session=db_session_with_containers,
                schedule_id=str(uuid4()),
                updates=SchedulePlanUpdate(node_id="node-new"),
            )

    def test_delete_schedule_removes_row(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        plan = ScheduleServiceIntegrationFactory.create_schedule_plan(
            db_session_with_containers,
            tenant_id=tenant.id,
        )
        ScheduleService.delete_schedule(
            session=db_session_with_containers,
            schedule_id=plan.id,
        )
        db_session_with_containers.commit()
        assert db_session_with_containers.get(WorkflowSchedulePlan, plan.id) is None

    def test_delete_schedule_not_found_raises(self, db_session_with_containers: Session):
        with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
            ScheduleService.delete_schedule(
                session=db_session_with_containers,
                schedule_id=str(uuid4()),
            )

    def test_get_tenant_owner_returns_owner_account(self, db_session_with_containers: Session):
        owner, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(
            db_session_with_containers,
            role=TenantAccountRole.OWNER,
        )
        resolved = ScheduleService.get_tenant_owner(
            session=db_session_with_containers,
            tenant_id=tenant.id,
        )
        assert resolved.id == owner.id

    def test_get_tenant_owner_falls_back_to_admin(self, db_session_with_containers: Session):
        admin, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(
            db_session_with_containers,
            role=TenantAccountRole.ADMIN,
        )
        resolved = ScheduleService.get_tenant_owner(
            session=db_session_with_containers,
            tenant_id=tenant.id,
        )
        assert resolved.id == admin.id

    def test_get_tenant_owner_raises_when_account_record_missing(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        # Wipe real memberships, then point an OWNER join at an account id with no row.
        db_session_with_containers.execute(delete(TenantAccountJoin))
        missing_account_id = str(uuid4())
        dangling_join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=missing_account_id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(dangling_join)
        db_session_with_containers.commit()
        with pytest.raises(AccountNotFoundError, match=missing_account_id):
            ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id)

    def test_get_tenant_owner_raises_when_no_owner_or_admin_found(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        db_session_with_containers.execute(delete(TenantAccountJoin))
        db_session_with_containers.commit()
        with pytest.raises(AccountNotFoundError, match=tenant.id):
            ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id)

    def test_update_next_run_at_updates_persisted_value(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        plan = ScheduleServiceIntegrationFactory.create_schedule_plan(
            db_session_with_containers,
            tenant_id=tenant.id,
        )
        expected_next_run = datetime(2026, 1, 3, 10, 30, 0)
        with pytest.MonkeyPatch.context() as mp:
            mp.setattr(
                "services.trigger.schedule_service.calculate_next_run_at",
                lambda *_args, **_kwargs: expected_next_run,
            )
            recomputed = ScheduleService.update_next_run_at(
                session=db_session_with_containers,
                schedule_id=plan.id,
            )
        db_session_with_containers.refresh(plan)
        assert recomputed == expected_next_run
        assert plan.next_run_at == expected_next_run

    def test_update_next_run_at_raises_when_schedule_not_found(self, db_session_with_containers: Session):
        with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
            ScheduleService.update_next_run_at(
                session=db_session_with_containers,
                schedule_id=str(uuid4()),
            )
class TestSyncScheduleFromWorkflowIntegration:
    """Create / update / remove paths of sync_schedule_from_workflow against SQL."""

    def test_sync_schedule_create_new(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        app_id = str(uuid4())
        expected_next_run = datetime(2026, 1, 4, 10, 30, 0)
        with pytest.MonkeyPatch.context() as mp:
            # Pin the cron computation so the persisted next_run_at is deterministic.
            mp.setattr(
                "services.trigger.schedule_service.calculate_next_run_at",
                lambda *_args, **_kwargs: expected_next_run,
            )
            synced = sync_schedule_from_workflow(
                tenant_id=tenant.id,
                app_id=app_id,
                workflow=_cron_workflow(),
            )
        assert synced is not None
        stored = db_session_with_containers.execute(
            select(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app_id)
        ).scalar_one()
        assert stored.node_id == "start"
        assert stored.cron_expression == "30 10 * * *"
        assert stored.timezone == "UTC"
        assert stored.next_run_at == expected_next_run

    def test_sync_schedule_update_existing(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        app_id = str(uuid4())
        existing_plan = ScheduleServiceIntegrationFactory.create_schedule_plan(
            db_session_with_containers,
            tenant_id=tenant.id,
            app_id=app_id,
            node_id="old-start",
            cron_expression="30 10 * * *",
            timezone="UTC",
        )
        existing_id = existing_plan.id
        expected_next_run = datetime(2026, 1, 5, 12, 0, 0)
        with pytest.MonkeyPatch.context() as mp:
            mp.setattr(
                "services.trigger.schedule_service.calculate_next_run_at",
                lambda *_args, **_kwargs: expected_next_run,
            )
            synced = sync_schedule_from_workflow(
                tenant_id=tenant.id,
                app_id=app_id,
                workflow=_cron_workflow(
                    node_id="start",
                    cron_expression="0 12 * * *",
                    timezone="America/New_York",
                ),
            )
        assert synced is not None
        # Drop cached state so the assertions read the row as the sync left it.
        db_session_with_containers.expire_all()
        stored = db_session_with_containers.get(WorkflowSchedulePlan, existing_id)
        assert stored is not None
        assert stored.node_id == "start"
        assert stored.cron_expression == "0 12 * * *"
        assert stored.timezone == "America/New_York"
        assert stored.next_run_at == expected_next_run

    def test_sync_schedule_remove_when_no_config(self, db_session_with_containers: Session):
        _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
        app_id = str(uuid4())
        existing_plan = ScheduleServiceIntegrationFactory.create_schedule_plan(
            db_session_with_containers,
            tenant_id=tenant.id,
            app_id=app_id,
        )
        existing_id = existing_plan.id
        synced = sync_schedule_from_workflow(
            tenant_id=tenant.id,
            app_id=app_id,
            workflow=_no_schedule_workflow(),
        )
        # A workflow without a schedule trigger must delete the stale plan.
        assert synced is None
        db_session_with_containers.expire_all()
        assert db_session_with_containers.get(WorkflowSchedulePlan, existing_id) is None

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask:
redis_client.set(indexing_cache_key, "processing", ex=300)
# Verify logs exist before processing
existing_logs = (
db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
)
existing_logs = db_session_with_containers.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id)
).all()
assert len(existing_logs) == 2
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify auto disable logs were deleted
remaining_logs = (
db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
)
remaining_logs = db_session_with_containers.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id)
).all()
assert len(remaining_logs) == 0
# Verify index processing occurred normally

View File

@ -11,6 +11,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit() # Ensure all changes are committed
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_with_image_files(
@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Verify that the task completed successfully by checking the log output
@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
# Verify database cleanup
db_session_with_containers.commit()
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_dataset_not_found(
@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Document should still exist since cleanup failed
existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first()
existing_document = db_session_with_containers.scalar(
select(Document).where(Document.id == document_id).limit(1)
)
assert existing_document is not None
def test_batch_clean_document_task_storage_cleanup_failure(
@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted from database
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted from database
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_multiple_documents(
@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that all upload files are deleted
for file_id in file_ids:
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
assert deleted_file is None
def test_batch_clean_document_task_different_doc_forms(
@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
except Exception as e:
@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check if the segment still exists (task may have failed before deletion)
existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
existing_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
if existing_segment is not None:
# If segment still exists, the task failed before deletion
# This is acceptable in test environments with external service issues
@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that all upload files are deleted
for file_id in file_ids:
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
assert deleted_file is None
def test_batch_clean_document_task_integration_with_real_database(
@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Verify initial state
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 3
)
assert (
db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1))
is not None
)
# Store original IDs for verification
document_id = document.id
@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
# Verify final database state
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id)
)
== 0
)
assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None

View File

@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask:
from extensions.ext_redis import redis_client
# Clear all test data
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache
@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results
# Check that segments were created
segments = (
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
)
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position)
).all()
assert len(segments) == 3
# Verify segment content and metadata
@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created (since dataset doesn't exist)
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify no documents were modified
documents = db_session_with_containers.query(Document).all()
documents = db_session_with_containers.scalars(select(Document)).all()
assert len(documents) == 0
def test_batch_create_segment_to_index_task_document_not_found(
@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify dataset remains unchanged (no segments were added to the dataset)
db_session_with_containers.refresh(dataset)
segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments_for_dataset = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(segments_for_dataset) == 0
def test_batch_create_segment_to_index_task_document_not_available(
@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error"
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all()
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
assert len(segments) == 0
def test_batch_create_segment_to_index_task_upload_file_not_found(
@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify document remains unchanged
@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify error handling
# Since exception was raised, no segments should be created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify document remains unchanged
@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results
# Check that new segments were created with correct positions
all_segments = (
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
)
all_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position)
).all()
assert len(all_segments) == 6 # 3 existing + 3 new
# Verify position ordering

View File

@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -52,18 +53,18 @@ class TestCleanDatasetTask:
from extensions.ext_redis import redis_client
# Clear all test data using the provided session fixture
db_session_with_containers.query(DatasetMetadataBinding).delete()
db_session_with_containers.query(DatasetMetadata).delete()
db_session_with_containers.query(AppDatasetJoin).delete()
db_session_with_containers.query(DatasetQuery).delete()
db_session_with_containers.query(DatasetProcessRule).delete()
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DatasetMetadataBinding))
db_session_with_containers.execute(delete(DatasetMetadata))
db_session_with_containers.execute(delete(AppDatasetJoin))
db_session_with_containers.execute(delete(DatasetQuery))
db_session_with_containers.execute(delete(DatasetProcessRule))
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache
@ -302,28 +303,40 @@ class TestCleanDatasetTask:
# Verify results
# Check that dataset-related data was cleaned up
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
documents = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
assert len(documents) == 0
segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(segments) == 0
# Check that metadata and bindings were cleaned up
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(metadata) == 0
bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(bindings) == 0
# Check that process rules and queries were cleaned up
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
process_rules = db_session_with_containers.scalars(
select(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset.id)
).all()
assert len(process_rules) == 0
queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
queries = db_session_with_containers.scalars(
select(DatasetQuery).where(DatasetQuery.dataset_id == dataset.id)
).all()
assert len(queries) == 0
# Check that app dataset joins were cleaned up
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
app_joins = db_session_with_containers.scalars(
select(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset.id)
).all()
assert len(app_joins) == 0
# Verify index processor was called
@ -414,24 +427,32 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(upload_file_ids))
).all()
assert len(remaining_files) == 0
# Check that metadata and bindings were cleaned up
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
remaining_bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(remaining_bindings) == 0
# Verify index processor was called
@ -485,12 +506,14 @@ class TestCleanDatasetTask:
# Check that all data was cleaned up
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
remaining_segments = (
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
)
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Recreate data for next test case
@ -538,11 +561,15 @@ class TestCleanDatasetTask:
# Verify results - even with vector cleanup failure, documents and segments should be deleted
# Check that documents were still deleted despite vector cleanup failure
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite vector cleanup failure
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Verify that index processor was called and failed
@ -622,18 +649,22 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Check that all image files were deleted from database
image_file_ids = [f.id for f in image_files]
remaining_image_files = (
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
)
remaining_image_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(image_file_ids))
).all()
assert len(remaining_image_files) == 0
# Verify that storage.delete was called for each image file
@ -738,24 +769,32 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(upload_file_ids))
).all()
assert len(remaining_files) == 0
# Check that all metadata and bindings were deleted
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
remaining_bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(remaining_bindings) == 0
# Verify performance expectations
@ -826,7 +865,9 @@ class TestCleanDatasetTask:
# Check that upload file was still deleted from database despite storage failure
# Note: When storage operations fail, the upload file may not be deleted
# This demonstrates that the cleanup process continues even with storage errors
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id == upload_file.id)
).all()
# The upload file should still be deleted from the database even if storage cleanup fails
# However, this depends on the specific implementation of clean_dataset_task
if len(remaining_files) > 0:
@ -976,19 +1017,27 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id == upload_file_id)
).all()
assert len(remaining_files) == 0
# Check that all metadata was deleted
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0
# Verify that storage.delete was called

View File

@ -11,6 +11,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy import func, select
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document, DocumentSegment
@ -145,11 +146,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Verify data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 3
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(document_ids))
)
== 3
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
)
== 6
)
@ -158,9 +164,9 @@ class TestCleanNotionDocumentTask:
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
)
== 0
)
@ -323,9 +329,9 @@ class TestCleanNotionDocumentTask:
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == document.id)
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)
@ -411,7 +417,9 @@ class TestCleanNotionDocumentTask:
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)
@ -499,9 +507,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Verify all data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 5
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
)
== 5
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 10
)
@ -514,19 +529,26 @@ class TestCleanNotionDocumentTask:
# Verify only specified documents' segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(documents_to_clean))
.count()
db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id.in_(documents_to_clean))
)
== 0
)
# Verify remaining documents and segments are intact
remaining_docs = [doc.id for doc in documents[3:]]
assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(remaining_docs))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs))
)
== 2
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs))
)
== 4
)
@ -613,7 +635,9 @@ class TestCleanNotionDocumentTask:
# Verify all segments exist before cleanup
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 4
)
@ -622,7 +646,9 @@ class TestCleanNotionDocumentTask:
# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)
@ -795,11 +821,15 @@ class TestCleanNotionDocumentTask:
# Verify all data exists before cleanup
assert (
db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
)
== num_documents
)
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== num_documents * num_segments_per_doc
)
@ -809,7 +839,9 @@ class TestCleanNotionDocumentTask:
# Verify all segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 0
)
@ -906,8 +938,8 @@ class TestCleanNotionDocumentTask:
# Verify all data exists before cleanup
# Note: There may be documents from previous tests, so we check for at least 3
assert db_session_with_containers.query(Document).count() >= 3
assert db_session_with_containers.query(DocumentSegment).count() >= 9
assert db_session_with_containers.scalar(select(func.count()).select_from(Document)) >= 3
assert db_session_with_containers.scalar(select(func.count()).select_from(DocumentSegment)) >= 9
# Clean up documents from only the first dataset
target_dataset = datasets[0]
@ -919,19 +951,26 @@ class TestCleanNotionDocumentTask:
# Verify only documents' segments from target dataset are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == target_document.id)
.count()
db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == target_document.id)
)
== 0
)
# Verify documents from other datasets remain intact
remaining_docs = [doc.id for doc in all_documents[1:]]
assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(remaining_docs))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs))
)
== 2
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs))
)
== 6
)
@ -1028,11 +1067,13 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Verify all data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == len(
document_statuses
)
assert db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
) == len(document_statuses)
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== len(document_statuses) * 2
)
@ -1042,7 +1083,9 @@ class TestCleanNotionDocumentTask:
# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 0
)
@ -1142,9 +1185,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Verify data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 1
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id == document.id)
)
== 1
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 3
)
@ -1153,7 +1203,9 @@ class TestCleanNotionDocumentTask:
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)

View File

@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration:
def _query_document(self, db_session_with_containers, document_id: str) -> Document | None:
"""Return the latest persisted document state."""
return db_session_with_containers.query(Document).where(Document.id == document_id).first()
return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1))
def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None:
"""Assert all target documents are persisted in parsing status."""

View File

@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe
from unittest.mock import MagicMock, patch
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -471,9 +472,9 @@ class TestDisableSegmentsFromIndexTask:
db_session_with_containers.refresh(segments[1])
# Check that segments are re-enabled after error
updated_segments = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
)
updated_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
).all()
for segment in updated_segments:
assert segment.enabled is True

View File

@ -12,6 +12,7 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy import delete, func, select, update
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask:
"""Test that task raises error when data_source_info is empty."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
db_session_with_containers.query(Document).where(Document.id == context["document"].id).update(
{"data_source_info": None}
db_session_with_containers.execute(
update(Document).where(Document.id == context["document"].id).values(data_source_info=None)
)
db_session_with_containers.commit()
@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.ERROR
@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.COMPLETED
@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask:
context = self._create_notion_sync_context(db_session_with_containers)
def _delete_dataset_before_clean() -> str:
db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete()
db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id))
db_session_with_containers.commit()
return "2024-01-02T00:00:00Z"
@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.ERROR

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import func, select
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -123,13 +124,13 @@ class TestDocumentIndexingUpdateTask:
db_session_with_containers.expire_all()
# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
updated = db_session_with_containers.scalar(select(Document).where(Document.id == document.id).limit(1))
assert updated.indexing_status == IndexingStatus.PARSING
assert updated.processing_started_at is not None
# Segments should be deleted
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
remaining = db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
assert remaining == 0
@ -167,8 +168,8 @@ class TestDocumentIndexingUpdateTask:
mock_external_dependencies["runner_instance"].run.assert_called_once()
# Segments should remain (since clean failed before DB delete)
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
remaining = db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
assert remaining > 0

View File

@ -236,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.
_ = DifyConfig().normalized_pubsub_redis_url
def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
config = DifyConfig(_env_file=None)
assert config.REDIS_KEY_PREFIX == ""
def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a")
config = DifyConfig(_env_file=None)
assert config.REDIS_KEY_PREFIX == "enterprise-a"
@pytest.mark.parametrize(
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
[

View File

@ -0,0 +1,139 @@
"""Unit tests for console app import endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
self.app_id = app_id
def model_dump(self, mode: str = "json"):
return {"status": self.status, "app_id": self.app_id}
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
    """Patch FeatureService so ``webapp_auth.enabled`` reports the given flag."""
    stub = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))

    def _fake_get_system_features():
        return stub

    monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", _fake_get_system_features)
def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
    """Replace the module's db engine and ``Session`` factory with a mock.

    The mock acts as its own context manager so ``with Session(...) as s:``
    yields it; the caller gets it back to assert on ``commit``/``rollback``.
    """
    stub = MagicMock()
    stub.__enter__.return_value = stub
    stub.__exit__.return_value = None

    def _session_factory(*_args, **_kwargs):
        return stub

    monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
    monkeypatch.setattr(app_import_module, "Session", _session_factory)
    return stub
class TestAppImportApi:
    """Endpoint tests for AppImportApi (POST /console/api/apps/imports)."""

    @pytest.fixture
    def api(self):
        # Fresh resource instance per test; routing/auth decorators are
        # bypassed by calling the unwrapped handler directly.
        return app_import_module.AppImportApi()

    def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
        """A FAILED import result maps to HTTP 400 and rolls the session back."""
        method = _unwrap(api.post)
        _install_features(monkeypatch, enabled=False)
        session = _mock_session(monkeypatch)
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "import_app",
            lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
        )
        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
            response, status = method()
        # Nothing may be persisted on failure.
        session.rollback.assert_called_once_with()
        session.commit.assert_not_called()
        assert status == 400
        assert response["status"] == ImportStatus.FAILED

    def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
        """A PENDING import result maps to HTTP 202 and commits the session."""
        method = _unwrap(api.post)
        _install_features(monkeypatch, enabled=False)
        session = _mock_session(monkeypatch)
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "import_app",
            lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
        )
        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
            response, status = method()
        session.commit.assert_called_once_with()
        session.rollback.assert_not_called()
        assert status == 202
        assert response["status"] == ImportStatus.PENDING

    def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
        """With webapp auth enabled, a COMPLETED import also sets access mode to private."""
        method = _unwrap(api.post)
        _install_features(monkeypatch, enabled=True)
        session = _mock_session(monkeypatch)
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "import_app",
            lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
        )
        update_access = MagicMock()
        monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
            response, status = method()
        session.commit.assert_called_once_with()
        session.rollback.assert_not_called()
        # The imported app is switched to private access exactly once.
        update_access.assert_called_once_with("app-123", "private")
        assert status == 200
        assert response["status"] == ImportStatus.COMPLETED
class TestAppImportConfirmApi:
    """Endpoint tests for AppImportConfirmApi (POST /console/api/apps/imports/<id>/confirm)."""

    @pytest.fixture
    def api(self):
        return app_import_module.AppImportConfirmApi()

    def test_import_confirm_returns_failed_status_and_rolls_back(
        self, api, app, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A FAILED confirm result maps to HTTP 400 and rolls the session back."""
        method = _unwrap(api.post)
        session = _mock_session(monkeypatch)
        monkeypatch.setattr(
            app_import_module.AppDslService,
            "confirm_import",
            lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
        )
        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
        with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
            response, status = method(import_id="import-1")
        # Failed confirmation must not commit anything.
        session.rollback.assert_called_once_with()
        session.commit.assert_not_called()
        assert status == 400
        assert response["status"] == ImportStatus.FAILED

View File

@ -138,12 +138,15 @@ def app_models(app_module):
def patch_signed_url(monkeypatch, app_module):
"""Ensure icon URL generation uses a deterministic helper for tests."""
def _fake_signed_url(key: str | None) -> str | None:
if not key:
def _fake_build_icon_url(_icon_type, key: str | None) -> str | None:
if key is None:
return None
icon_type = str(_icon_type).lower()
if icon_type != "image":
return None
return f"signed:{key}"
monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
monkeypatch.setattr(app_module, "build_icon_url", _fake_build_icon_url)
def _ts(hour: int = 12) -> datetime:

View File

@ -1,42 +0,0 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from controllers.console.app.conversation import _get_conversation
def test_get_conversation_mark_read_keeps_updated_at_unchanged():
app_model = SimpleNamespace(id="app-id")
account = SimpleNamespace(id="account-id")
conversation = MagicMock()
conversation.id = "conversation-id"
with (
patch(
"controllers.console.app.conversation.current_account_with_tenant",
return_value=(account, None),
autospec=True,
),
patch(
"controllers.console.app.conversation.naive_utc_now",
return_value=datetime(2026, 2, 9, 0, 0, 0),
autospec=True,
),
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
):
mock_session.scalar.return_value = conversation
_get_conversation(app_model, "conversation-id")
statement = mock_session.execute.call_args[0][0]
compiled = statement.compile()
sql_text = str(compiled).lower()
compact_sql_text = sql_text.replace(" ", "")
params = compiled.params
assert "updated_at=current_timestamp" not in compact_sql_text
assert "updated_at=conversations.updated_at" in compact_sql_text
assert "read_at=:read_at" in compact_sql_text
assert "read_account_id=:read_account_id" in compact_sql_text
assert params["read_at"] == datetime(2026, 2, 9, 0, 0, 0)
assert params["read_account_id"] == "account-id"

View File

@ -0,0 +1,108 @@
from __future__ import annotations
from contextlib import nullcontext
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from graphon.variables.types import SegmentType
from pydantic import ValidationError
from controllers.console.app import conversation_variables as conversation_variables_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
    """GET conversation-variables returns a single-page envelope with epoch timestamps."""
    api = conversation_variables_module.ConversationVariablesApi()
    method = _unwrap(api.get)
    created_at = datetime(2026, 1, 1, tzinfo=UTC)
    updated_at = datetime(2026, 1, 2, tzinfo=UTC)
    # One fake DB row exposing only what the endpoint reads: timestamps plus
    # a to_variable() whose model_dump() supplies the variable payload.
    row = SimpleNamespace(
        created_at=created_at,
        updated_at=updated_at,
        to_variable=lambda: SimpleNamespace(
            model_dump=lambda: {
                "id": "var-1",
                "name": "my_var",
                "value_type": "string",
                "value": "value",
                "description": "desc",
            }
        ),
    )
    session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
    # Stub out the engine and the sessionmaker so the endpoint's
    # `sessionmaker(...).begin()` context yields our fake session.
    monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
    monkeypatch.setattr(
        conversation_variables_module,
        "sessionmaker",
        lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
    )
    with app.test_request_context(
        "/console/api/apps/app-1/conversation-variables",
        method="GET",
        query_string={"conversation_id": "conv-1"},
    ):
        response = method(app_model=SimpleNamespace(id="app-1"))
    assert response["page"] == 1
    assert response["limit"] == 100
    assert response["total"] == 1
    assert response["has_more"] is False
    assert response["data"][0]["id"] == "var-1"
    # Datetimes are serialized as integer epoch seconds.
    assert response["data"][0]["created_at"] == int(created_at.timestamp())
    assert response["data"][0]["updated_at"] == int(updated_at.timestamp())
def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None:
    """Integer segment types are reported as "number" and values are stringified."""
    api = conversation_variables_module.ConversationVariablesApi()
    method = _unwrap(api.get)
    # Row with None timestamps and an INTEGER-typed variable to exercise the
    # endpoint's value_type/value normalization path.
    row = SimpleNamespace(
        created_at=None,
        updated_at=None,
        to_variable=lambda: SimpleNamespace(
            model_dump=lambda: {
                "id": "var-2",
                "name": "my_var_2",
                "value_type": SegmentType.INTEGER,
                "value": 42,
                "description": None,
            }
        ),
    )
    session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
    monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
    monkeypatch.setattr(
        conversation_variables_module,
        "sessionmaker",
        lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
    )
    with app.test_request_context(
        "/console/api/apps/app-1/conversation-variables",
        method="GET",
        query_string={"conversation_id": "conv-1"},
    ):
        response = method(app_model=SimpleNamespace(id="app-1"))
    # INTEGER segments surface as the UI-facing "number" type, value as text.
    assert response["data"][0]["value_type"] == "number"
    assert response["data"][0]["value"] == "42"
def test_get_conversation_variables_requires_conversation_id(app) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"):
with pytest.raises(ValidationError):
method(app_model=SimpleNamespace(id="app-1"))

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from controllers.console.app import message as message_module
@ -120,3 +122,24 @@ def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> N
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
assert len(response.data) == 2
assert response.data[0] == "What is AI?"
def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageDetailResponse normalizes alias fields and datetime timestamps."""
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = message_module.MessageDetailResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"conversation_id": "550e8400-e29b-41d4-a716-446655440001",
"inputs": {"foo": "bar"},
"query": "hello",
"re_sign_file_url_answer": "world",
"from_source": "user",
"status": "normal",
"created_at": created_at,
"message_metadata_dict": {"token_usage": 3},
}
)
assert response.answer == "world"
assert response.metadata == {"token_usage": 3}
assert response.created_at == int(created_at.timestamp())

View File

@ -258,6 +258,63 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
assert exc.value.description == "invalid workflow graph"
def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.PublishedAllWorkflowApi()
handler = _unwrap(api.get)
session_state = {"open": False}
class _SessionContext:
def __enter__(self):
session_state["open"] = True
return object()
def __exit__(self, exc_type, exc, tb):
session_state["open"] = False
return False
class _SessionMaker:
def begin(self):
return _SessionContext()
class _Workflow:
@property
def id(self):
assert session_state["open"] is True
return "w1"
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
lambda: SimpleNamespace(
get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False),
),
)
def _fake_marshal(items, fields):
assert session_state["open"] is True
return [{"id": item.id} for item in items]
monkeypatch.setattr(workflow_module, "marshal", _fake_marshal)
with app.test_request_context(
"/apps/app/workflows",
method="GET",
query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
):
response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
assert response == {
"items": [{"id": "w1"}],
"page": 1,
"limit": 10,
"has_more": False,
}
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)

View File

@ -0,0 +1,85 @@
from __future__ import annotations
from datetime import UTC, datetime
from graphon.enums import WorkflowExecutionStatus
from controllers.console.app import workflow_app_log as workflow_app_log_module
def test_workflow_app_log_query_parses_bool_and_datetime():
query = workflow_app_log_module.WorkflowAppLogQuery.model_validate(
{
"detail": "true",
"created_at__before": "2026-01-02T03:04:05Z",
"page": "2",
"limit": "10",
}
)
assert query.detail is True
assert query.created_at__before == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
assert query.page == 2
assert query.limit == 10
def test_workflow_app_log_pagination_response_normalizes_nested_fields():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = workflow_app_log_module.WorkflowAppLogPaginationResponse.model_validate(
{
"page": 1,
"limit": 20,
"total": 1,
"has_more": False,
"data": [
{
"id": "log-1",
"workflow_run": {
"id": "run-1",
"status": WorkflowExecutionStatus.SUCCEEDED,
"created_at": created_at,
"finished_at": created_at,
},
"details": {"trigger_metadata": {}},
"created_by_account": {"id": "acc-1", "name": "acc", "email": "acc@example.com"},
"created_at": created_at,
}
],
}
).model_dump(mode="json")
assert response["data"][0]["workflow_run"]["status"] == "succeeded"
assert response["data"][0]["workflow_run"]["created_at"] == int(created_at.timestamp())
assert response["data"][0]["created_at"] == int(created_at.timestamp())
def test_workflow_archived_log_pagination_response_normalizes_nested_fields():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = workflow_app_log_module.WorkflowArchivedLogPaginationResponse.model_validate(
{
"page": 1,
"limit": 20,
"total": 1,
"has_more": False,
"data": [
{
"id": "archived-1",
"workflow_run": {
"id": "run-1",
"status": WorkflowExecutionStatus.FAILED,
},
"trigger_metadata": {"type": "trigger-plugin"},
"created_by_end_user": {
"id": "eu-1",
"type": "anonymous",
"is_anonymous": True,
"session_id": "session-1",
},
"created_at": created_at,
}
],
}
).model_dump(mode="json")
assert response["data"][0]["workflow_run"]["status"] == "failed"
assert response["data"][0]["created_at"] == int(created_at.timestamp())

View File

@ -0,0 +1,54 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from controllers.console.app import workflow_trigger as workflow_trigger_module
def test_parser_models_validate():
parser = workflow_trigger_module.Parser(node_id="node-1")
enable_parser = workflow_trigger_module.ParserEnable(
trigger_id="550e8400-e29b-41d4-a716-446655440000", enable_trigger=True
)
assert parser.node_id == "node-1"
assert enable_parser.enable_trigger is True
def test_workflow_trigger_response_serializes_datetime():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
trigger = SimpleNamespace(
id="trigger-1",
trigger_type="trigger-plugin",
title="Trigger",
node_id="node-1",
provider_name="provider",
icon="https://example.com/icon",
status="enabled",
created_at=created_at,
updated_at=created_at,
)
payload = workflow_trigger_module.WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(
mode="json"
)
assert payload["id"] == "trigger-1"
assert payload["created_at"] == "2026-01-02T03:04:05Z"
assert payload["updated_at"] == "2026-01-02T03:04:05Z"
def test_webhook_trigger_response_serializes_datetime():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
webhook = {
"id": "webhook-1",
"webhook_id": "whk-1",
"webhook_url": "https://example.com/hook",
"webhook_debug_url": "https://example.com/hook/debug",
"node_id": "node-1",
"created_at": created_at,
}
payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json")
assert payload["webhook_id"] == "whk-1"
assert payload["created_at"] == "2026-01-02T03:04:05Z"

View File

@ -99,6 +99,57 @@ class TestHitTestingApi:
assert "records" in result
assert result["records"] == []
def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "what is vector search",
}
records = [
{
"segment": None,
"child_chunks": [],
"score": None,
"tsne_position": None,
"files": [],
"summary": None,
}
]
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
return_value=dataset,
),
patch.object(
HitTestingApi,
"hit_testing_args_check",
),
patch.object(
HitTestingApi,
"perform_hit_testing",
return_value={"query": payload["query"], "records": records},
),
):
result = method(api, dataset_id)
assert result["query"] == payload["query"]
assert result["records"] == records
def test_hit_testing_dataset_not_found(self, app, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import controllers.console.explore.recommended_app as module
from models.model import AppMode, IconType
def unwrap(func):
@ -90,3 +91,48 @@ class TestRecommendedAppApi:
service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111")
assert result == result_data
class TestRecommendedAppResponseModels:
def test_recommended_app_info_response_computes_icon_url(self):
with patch.object(module, "build_icon_url", return_value="https://signed/icon.png"):
payload = module.RecommendedAppInfoResponse.model_validate(
{
"id": "app-1",
"name": "App",
"mode": AppMode.CHAT,
"icon": "icon.png",
"icon_type": IconType.IMAGE,
"icon_background": "#fff",
}
).model_dump(mode="json")
assert payload["icon_url"] == "https://signed/icon.png"
def test_recommended_app_list_response_serialization(self):
response = module.RecommendedAppListResponse.model_validate(
{
"recommended_apps": [
{
"app": {
"id": "app-1",
"name": "App",
"mode": "chat",
"icon": "icon.png",
"icon_type": "emoji",
"icon_background": "#fff",
},
"app_id": "app-1",
"description": "desc",
"category": "cat",
"position": 1,
"is_listed": True,
"can_trial": False,
}
],
"categories": ["cat"],
}
).model_dump(mode="json")
assert response["recommended_apps"][0]["app_id"] == "app-1"
assert response["categories"] == ["cat"]

View File

@ -18,6 +18,7 @@ from controllers.console.workspace.workspace import (
CustomConfigWorkspaceApi,
SwitchWorkspaceApi,
TenantApi,
TenantInfoResponse,
TenantListApi,
WebappLogoWorkspaceApi,
WorkspaceInfoApi,
@ -435,6 +436,23 @@ class TestTenantApi:
assert status == 200
class TestTenantInfoResponse:
def test_tenant_info_response_normalizes_enum_and_datetime(self):
created_at = naive_utc_now()
payload = TenantInfoResponse.model_validate(
{
"id": "t1",
"status": TenantStatus.NORMAL,
"plan": CloudPlan.TEAM,
"created_at": created_at,
}
).model_dump(mode="json")
assert payload["status"] == "normal"
assert payload["plan"] == "team"
assert payload["created_at"] == int(created_at.timestamp())
class TestSwitchWorkspaceApi:
def test_switch_success(self, app):
api = SwitchWorkspaceApi()

View File

@ -102,16 +102,16 @@ class TestEnterpriseAppDSLImport:
@pytest.fixture
def _mock_import_deps(self):
"""Patch db, sessionmaker, and AppDslService for import handler tests."""
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock())
mock_session_ctx.__exit__ = MagicMock(return_value=False)
mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx)))
"""Patch db, Session, and AppDslService for import handler tests."""
mock_session = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=False)
with (
patch("controllers.inner_api.app.dsl.db"),
patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker),
patch("controllers.inner_api.app.dsl.Session", return_value=mock_session),
patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls,
):
self._mock_session = mock_session
self._mock_dsl = MagicMock()
mock_dsl_cls.return_value = self._mock_dsl
yield
@ -147,6 +147,8 @@ class TestEnterpriseAppDSLImport:
assert status_code == 200
assert body["status"] == "completed"
mock_account.set_tenant_id.assert_called_once_with("ws-123")
self._mock_session.commit.assert_called_once_with()
self._mock_session.rollback.assert_not_called()
@pytest.mark.usefixtures("_mock_import_deps")
@patch("controllers.inner_api.app.dsl._get_active_account")
@ -162,6 +164,8 @@ class TestEnterpriseAppDSLImport:
assert status_code == 202
assert body["status"] == "pending"
self._mock_session.commit.assert_called_once_with()
self._mock_session.rollback.assert_not_called()
@pytest.mark.usefixtures("_mock_import_deps")
@patch("controllers.inner_api.app.dsl._get_active_account")
@ -177,6 +181,8 @@ class TestEnterpriseAppDSLImport:
assert status_code == 400
assert body["status"] == "failed"
self._mock_session.rollback.assert_called_once_with()
self._mock_session.commit.assert_not_called()
@patch("controllers.inner_api.app.dsl._get_active_account")
def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask):

View File

@ -2,6 +2,7 @@
Unit tests for inner_api plugin decorators
"""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
@ -232,11 +233,11 @@ class TestGetUserTenant:
class PluginTestPayload:
"""Simple test payload class"""
def __init__(self, data: dict):
def __init__(self, data: dict[str, Any]):
self.value = data.get("value")
@classmethod
def model_validate(cls, data: dict):
def model_validate(cls, data: dict[str, Any]):
return cls(data)
@ -277,7 +278,7 @@ class TestPluginData:
# Arrange
class InvalidPayload:
@classmethod
def model_validate(cls, data: dict):
def model_validate(cls, data: dict[str, Any]):
raise Exception("Validation failed")
@plugin_data(payload_type=InvalidPayload)

View File

@ -15,10 +15,12 @@ Focus on:
import sys
import uuid
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from graphon.variables.types import SegmentType
from werkzeug.exceptions import BadRequest, NotFound
import services
@ -29,6 +31,8 @@ from controllers.service_api.app.conversation import (
ConversationRenameApi,
ConversationRenamePayload,
ConversationVariableDetailApi,
ConversationVariableInfiniteScrollPaginationResponse,
ConversationVariableResponse,
ConversationVariablesApi,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
@ -261,6 +265,46 @@ class TestConversationVariableUpdatePayload:
assert payload.value == nested
class TestConversationVariableResponseModels:
def test_variable_response_normalizes_value_type_and_timestamps(self):
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = ConversationVariableResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SegmentType.INTEGER,
"value": 1,
"description": "desc",
"created_at": created_at,
"updated_at": created_at,
}
)
assert response.value_type == "number"
assert response.value == "1"
assert response.created_at == int(created_at.timestamp())
assert response.updated_at == int(created_at.timestamp())
def test_variable_pagination_response(self):
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
{
"limit": 1,
"has_more": False,
"data": [
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": "string",
"value": "bar",
}
],
}
)
assert response.limit == 1
assert response.has_more is False
assert len(response.data) == 1
assert response.data[0].name == "foo"
class TestConversationAppModeValidation:
"""Test app mode validation for conversation endpoints."""
@ -549,6 +593,44 @@ class TestConversationVariablesApiController:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,
"get_conversational_variable",
lambda *_args, **_kwargs: SimpleNamespace(
limit=1,
has_more=False,
data=[
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SegmentType.INTEGER,
"value": 1,
"created_at": created_at,
"updated_at": created_at,
}
],
),
)
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables?limit=20",
method="GET",
):
result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
assert result["limit"] == 1
assert result["has_more"] is False
assert result["data"][0]["value_type"] == "number"
assert result["data"][0]["value"] == "1"
assert result["data"][0]["created_at"] == int(created_at.timestamp())
class TestConversationVariableDetailApiController:
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -602,3 +684,41 @@ class TestConversationVariableDetailApiController:
c_id="00000000-0000-0000-0000-000000000001",
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
lambda *_args, **_kwargs: {
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SegmentType.INTEGER,
"value": 1,
"created_at": created_at,
"updated_at": created_at,
},
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables/2",
method="PUT",
json={"value": 1},
):
result = handler(
api,
app_model=app_model,
end_user=end_user,
c_id="00000000-0000-0000-0000-000000000001",
variable_id="00000000-0000-0000-0000-000000000002",
)
assert result["id"] == "550e8400-e29b-41d4-a716-446655440000"
assert result["value_type"] == "number"
assert result["value"] == "1"
assert result["created_at"] == int(created_at.timestamp())

View File

@ -15,6 +15,7 @@ Focus on:
import sys
import uuid
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock, patch
@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
def _make_mock_workflow_run(run_id: str = "run-1"):
run = Mock()
run.id = run_id
run.workflow_id = "wf-1"
run.status = WorkflowExecutionStatus.SUCCEEDED
run.inputs = {"input": "value"}
run.outputs_dict = {"output": "value"}
run.error = None
run.total_steps = 1
run.total_tokens = 10
run.created_at = datetime(2026, 1, 1, tzinfo=UTC)
run.finished_at = datetime(2026, 1, 1, tzinfo=UTC)
run.elapsed_time = 0.1
return run
class TestWorkflowRunPayload:
"""Test suite for WorkflowRunPayload Pydantic model."""
@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi:
handler(api, app_model=app_model, workflow_run_id="run")
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
run = SimpleNamespace(id="run")
run = _make_mock_workflow_run(run_id="run")
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
workflow_module = sys.modules["controllers.service_api.app.workflow"]
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
@ -373,7 +390,10 @@ class TestWorkflowRunDetailApi:
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
assert handler(api, app_model=app_model, workflow_run_id="run") == run
result = handler(api, app_model=app_model, workflow_run_id="run")
assert result["id"] == "run"
assert result["workflow_id"] == "wf-1"
assert result["status"] == "succeeded"
class TestWorkflowRunApi:
@ -490,7 +510,7 @@ class TestWorkflowAppLogApi:
monkeypatch.setattr(
WorkflowAppService,
"get_paginate_workflow_app_logs",
lambda *_args, **_kwargs: {"items": [], "total": 0},
lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []},
)
api = WorkflowAppLogApi()
@ -500,7 +520,7 @@ class TestWorkflowAppLogApi:
with app.test_request_context("/workflows/logs", method="GET"):
response = handler(api, app_model=app_model)
assert response == {"items": [], "total": 0}
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
# =============================================================================
@ -527,9 +547,8 @@ def mock_workflow_app():
class TestWorkflowRunDetailApiGet:
"""Test suite for WorkflowRunDetailApi.get() endpoint.
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``)
and ``@service_api_ns.marshal_with``. We call the unwrapped method
directly; ``marshal_with`` is a no-op when calling directly.
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``),
and we call the unwrapped method directly in tests.
"""
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet:
mock_workflow_app,
):
"""Test successful workflow run detail retrieval."""
mock_run = Mock()
mock_run.id = "run-1"
mock_run.status = "succeeded"
mock_run = _make_mock_workflow_run(run_id="run-1")
mock_repo = Mock()
mock_repo.get_workflow_run_by_id.return_value = mock_run
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet:
api = WorkflowRunDetailApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
assert result == mock_run
assert result["id"] == mock_run.id
assert result["status"] == "succeeded"
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost:
class TestWorkflowAppLogApiGet:
"""Test suite for WorkflowAppLogApi.get() endpoint.
``get`` is wrapped by ``@validate_app_token`` and
``@service_api_ns.marshal_with``.
``get`` is wrapped by ``@validate_app_token``.
"""
@patch("controllers.service_api.app.workflow.WorkflowAppService")
@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet:
):
"""Test successful workflow log retrieval."""
mock_pagination = Mock()
mock_pagination.page = 1
mock_pagination.limit = 20
mock_pagination.total = 0
mock_pagination.has_more = False
mock_pagination.data = []
mock_svc_instance = Mock()
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet:
api = WorkflowAppLogApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
assert result == mock_pagination
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}

View File

@ -1,270 +0,0 @@
"""
Unit tests for Service API Site controller
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.service_api.app.site import AppSiteApi
from models.account import TenantStatus
from models.model import App, Site
from tests.unit_tests.conftest import setup_mock_tenant_account_query
class TestAppSiteApi:
"""Test suite for AppSiteApi"""
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model with tenant."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.enable_api = True
mock_tenant = Mock()
mock_tenant.id = app.tenant_id
mock_tenant.status = TenantStatus.NORMAL
app.tenant = mock_tenant
return app
@pytest.fixture
def mock_site(self):
"""Create a mock Site model."""
site = Mock(spec=Site)
site.id = str(uuid.uuid4())
site.app_id = str(uuid.uuid4())
site.title = "Test Site"
site.icon = "icon-url"
site.icon_background = "#ffffff"
site.description = "Site description"
site.copyright = "Copyright 2024"
site.privacy_policy = "Privacy policy text"
site.custom_disclaimer = "Custom disclaimer"
site.default_language = "en-US"
site.prompt_public = True
site.show_workflow_steps = True
site.use_icon_as_answer_icon = False
site.chat_color_theme = "light"
site.chat_color_theme_inverted = False
site.icon_type = "image"
site.created_at = "2024-01-01T00:00:00"
site.updated_at = "2024-01-01T00:00:00"
return site
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_success(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
mock_site,
):
"""Test successful retrieval of site configuration."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
# Mock wraps.db for authentication
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site.db for site query
mock_db.session.scalar.return_value = mock_site
# Act
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
response = api.get()
# Assert
assert response["title"] == "Test Site"
assert response["icon"] == "icon-url"
assert response["description"] == "Site description"
mock_db.session.scalar.assert_called_once()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_not_found(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
):
"""Test that Forbidden is raised when site is not found."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site query to return None
mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
api.get()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_tenant_archived(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
mock_site,
):
"""Test that Forbidden is raised when tenant is archived."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site query
mock_db.session.scalar.return_value = mock_site
# Set tenant status to archived AFTER authentication
mock_app_model.tenant.status = TenantStatus.ARCHIVE
# Act & Assert
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
api.get()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_queries_by_app_id(
self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model
):
"""Test that site is queried using the app model's id."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
mock_site = Mock(spec=Site)
mock_site.id = str(uuid.uuid4())
mock_site.app_id = mock_app_model.id
mock_site.title = "Test Site"
mock_site.icon = "icon-url"
mock_site.icon_background = "#ffffff"
mock_site.description = "Site description"
mock_site.copyright = "Copyright 2024"
mock_site.privacy_policy = "Privacy policy text"
mock_site.custom_disclaimer = "Custom disclaimer"
mock_site.default_language = "en-US"
mock_site.prompt_public = True
mock_site.show_workflow_steps = True
mock_site.use_icon_as_answer_icon = False
mock_site.chat_color_theme = "light"
mock_site.chat_color_theme_inverted = False
mock_site.icon_type = "image"
mock_site.created_at = "2024-01-01T00:00:00"
mock_site.updated_at = "2024-01-01T00:00:00"
mock_db.session.scalar.return_value = mock_site
# Act
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
api.get()
# Assert
# The query was executed successfully (site returned), which validates the correct query was made
mock_db.session.scalar.assert_called_once()

View File

@ -1,3 +1,5 @@
from typing import Any
import pytest
from core.extension.extensible import ExtensionModule
@ -12,10 +14,10 @@ class TestExternalDataTool:
# Create a concrete subclass to test init
class ConcreteTool(ExternalDataTool):
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
return super().validate_config(tenant_id, config)
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: dict[str, Any], query: str | None = None) -> str:
return super().query(inputs, query)
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"})
@ -28,10 +30,10 @@ class TestExternalDataTool:
# Create a concrete subclass to test init
class ConcreteTool(ExternalDataTool):
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
pass
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: dict[str, Any], query: str | None = None) -> str:
return ""
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
@ -43,10 +45,10 @@ class TestExternalDataTool:
def test_validate_config_raises_not_implemented(self):
class ConcreteTool(ExternalDataTool):
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
return super().validate_config(tenant_id, config)
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: dict[str, Any], query: str | None = None) -> str:
return ""
with pytest.raises(NotImplementedError):
@ -55,10 +57,10 @@ class TestExternalDataTool:
def test_query_raises_not_implemented(self):
class ConcreteTool(ExternalDataTool):
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
pass
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: dict[str, Any], query: str | None = None) -> str:
return super().query(inputs, query)
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")

View File

@ -10,6 +10,7 @@ This module tests all aspects of the content moderation system including:
- Configuration validation
"""
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest
@ -28,7 +29,7 @@ class TestKeywordsModeration:
"""Test suite for custom keyword-based content moderation."""
@pytest.fixture
def keywords_config(self) -> dict:
def keywords_config(self) -> dict[str, Any]:
"""
Fixture providing a standard keywords moderation configuration.
@ -48,7 +49,7 @@ class TestKeywordsModeration:
}
@pytest.fixture
def keywords_moderation(self, keywords_config: dict) -> KeywordsModeration:
def keywords_moderation(self, keywords_config: dict[str, Any]) -> KeywordsModeration:
"""
Fixture providing a KeywordsModeration instance.
@ -64,7 +65,7 @@ class TestKeywordsModeration:
config=keywords_config,
)
def test_validate_config_success(self, keywords_config: dict):
def test_validate_config_success(self, keywords_config: dict[str, Any]):
"""Test successful validation of keywords moderation configuration."""
# Should not raise any exception
KeywordsModeration.validate_config("test-tenant", keywords_config)
@ -274,7 +275,7 @@ class TestOpenAIModeration:
"""Test suite for OpenAI-based content moderation."""
@pytest.fixture
def openai_config(self) -> dict:
def openai_config(self) -> dict[str, Any]:
"""
Fixture providing OpenAI moderation configuration.
@ -293,7 +294,7 @@ class TestOpenAIModeration:
}
@pytest.fixture
def openai_moderation(self, openai_config: dict) -> OpenAIModeration:
def openai_moderation(self, openai_config: dict[str, Any]) -> OpenAIModeration:
"""
Fixture providing an OpenAIModeration instance.
@ -309,7 +310,7 @@ class TestOpenAIModeration:
config=openai_config,
)
def test_validate_config_success(self, openai_config: dict):
def test_validate_config_success(self, openai_config: dict[str, Any]):
"""Test successful validation of OpenAI moderation configuration."""
# Should not raise any exception
OpenAIModeration.validate_config("test-tenant", openai_config)

View File

@ -1,5 +1,6 @@
import json
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
@ -57,7 +58,7 @@ class _FakeSelect:
return self
def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None):
def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict[str, Any] | None = None):
return SimpleNamespace(
data_source_type=data_source_type,
keyword_table_dict=keyword_table_dict,

View File

@ -1,4 +1,5 @@
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock, call, patch
from uuid import uuid4
@ -20,7 +21,7 @@ def create_mock_document(
doc_id: str,
score: float = 0.8,
provider: str = "dify",
additional_metadata: dict | None = None,
additional_metadata: dict[str, Any] | None = None,
) -> Document:
"""
Create a mock Document object for testing.

View File

@ -1,4 +1,5 @@
from types import SimpleNamespace
from typing import Any
from unittest.mock import Mock, patch
import pytest
@ -71,7 +72,9 @@ class TestParagraphIndexProcessor:
with pytest.raises(ValueError, match="No rules found in process rule"):
processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"})
def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None:
def test_transform_validates_segmentation(
self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any]
) -> None:
rules_without_segmentation = SimpleNamespace(segmentation=None)
with patch(
@ -84,7 +87,9 @@ class TestParagraphIndexProcessor:
process_rule={"mode": "custom", "rules": {"enabled": True}},
)
def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None:
def test_transform_builds_split_documents(
self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any]
) -> None:
source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"})
splitter = Mock()
splitter.split_documents.return_value = [

View File

@ -1,4 +1,5 @@
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pandas as pd
@ -77,7 +78,7 @@ class TestQAIndexProcessor:
processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"})
def test_transform_preview_calls_formatter_once(
self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app
self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app
) -> None:
document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"})
split_node = Document(page_content=".question", metadata={})
@ -119,7 +120,7 @@ class TestQAIndexProcessor:
mock_format.assert_called_once()
def test_transform_non_preview_uses_thread_batches(
self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app
self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app
) -> None:
documents = [
Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}),

Some files were not shown because too many files have changed in this diff Show More