mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
Merge branch 'main' into sandboxed-agent-rebase
Made-with: Cursor # Conflicts: # api/tests/unit_tests/controllers/console/app/test_message.py # api/tests/unit_tests/controllers/console/app/test_statistic.py # api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py # api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py # api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py # api/tests/unit_tests/controllers/console/auth/test_oauth_server.py # web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx # web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx # web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx # web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx # web/app/components/header/account-setting/data-source-page/panel/config-item.tsx # web/app/components/header/account-setting/data-source-page/panel/index.tsx # web/app/components/workflow/nodes/knowledge-retrieval/node.tsx # web/package.json # web/pnpm-lock.yaml
This commit is contained in:
commit
8775003732
@ -356,6 +356,9 @@ BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
|
||||
@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings):
|
||||
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
|
||||
default="COARSE_MODE",
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field(
|
||||
description="Auto build row count increment threshold (default is 500)",
|
||||
default=500,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field(
|
||||
description="Auto build row count increment ratio threshold (default is 0.05)",
|
||||
default=0.05,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field(
|
||||
description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)",
|
||||
default=300,
|
||||
)
|
||||
|
||||
@ -9,6 +9,7 @@ from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
||||
class BaseApiKeyListResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
token_prefix: str | None = None
|
||||
@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource):
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
assert self.resource_type is not None, "resource_type must be set"
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_tenant_id
|
||||
@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource):
|
||||
class BaseApiKeyResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
|
||||
@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
token_prefix = "ds-"
|
||||
@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
||||
@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
||||
)
|
||||
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||
.subquery()
|
||||
)
|
||||
@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
|
||||
@ -204,7 +204,7 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
app = db.session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
|
||||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
return server
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -245,27 +245,25 @@ class ChatMessageListApi(Resource):
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args.first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.first()
|
||||
first_message = db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
|
||||
)
|
||||
|
||||
if not first_message:
|
||||
raise NotFound("First message not found")
|
||||
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
.where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
@ -273,16 +271,14 @@ class ChatMessageListApi(Resource):
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args.limit:
|
||||
@ -327,7 +323,9 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
message_id = str(args.message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@ -376,7 +374,9 @@ class MessageAnnotationCountApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
||||
count = db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
|
||||
)
|
||||
|
||||
return {"count": count}
|
||||
|
||||
@ -480,7 +480,9 @@ class MessageApi(Resource):
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
|
||||
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
|
||||
if original_app_model_config is None:
|
||||
raise ValueError("Original app model config not found")
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
@ -75,7 +76,7 @@ class AppSite(Resource):
|
||||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
@ -2,6 +2,8 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
|
||||
return app_model
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
@ -112,6 +113,9 @@ class OAuthCallback(Resource):
|
||||
error_text = e.response.text
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
except ValueError as e:
|
||||
logger.warning("OAuth error with %s", provider, exc_info=True)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
|
||||
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
|
||||
@ -54,7 +54,7 @@ from fields.document_fields import document_status_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import SegmentStatus
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
@ -777,7 +777,7 @@ class DatasetIndexingStatusApi(Resource):
|
||||
class DatasetApiKeyApi(Resource):
|
||||
max_keys = 10
|
||||
token_prefix = "dataset-"
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
@ -826,7 +826,7 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete dataset API key")
|
||||
|
||||
@ -735,7 +735,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
created_from=created_from.value,
|
||||
created_from=created_from,
|
||||
created_by_role=self._created_by_role,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
|
||||
@ -112,8 +112,7 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
if instance is not None:
|
||||
return instance
|
||||
logger.warning(
|
||||
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance,"
|
||||
" node_id=%s",
|
||||
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
|
||||
node.id,
|
||||
)
|
||||
return None
|
||||
|
||||
@ -181,10 +181,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
arize_phoenix_config: ArizeConfig | PhoenixConfig,
|
||||
):
|
||||
super().__init__(arize_phoenix_config)
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
self.arize_phoenix_config = arize_phoenix_config
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
|
||||
@ -918,11 +918,11 @@ class ProviderManager:
|
||||
|
||||
trail_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL.value,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
paid_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.PAID.value,
|
||||
pool_type=ProviderQuotaType.PAID,
|
||||
)
|
||||
else:
|
||||
trail_pool = None
|
||||
|
||||
@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore
|
||||
from pymochow.model.database import Database
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
|
||||
from pymochow.model.schema import (
|
||||
AutoBuildRowCountIncrement,
|
||||
Field,
|
||||
FilteringIndex,
|
||||
HNSWParams,
|
||||
@ -51,6 +52,9 @@ class BaiduConfig(BaseModel):
|
||||
replicas: int = 3
|
||||
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
|
||||
inverted_index_parser_mode: str = "COARSE_MODE"
|
||||
auto_build_row_count_increment: int = 500
|
||||
auto_build_row_count_increment_ratio: float = 0.05
|
||||
rebuild_index_timeout_in_seconds: int = 300
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@ -107,18 +111,6 @@ class BaiduVector(BaseVector):
|
||||
rows.append(row)
|
||||
table.upsert(rows=rows)
|
||||
|
||||
# rebuild vector index after upsert finished
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.vector_index)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
|
||||
if res and res.code == 0:
|
||||
@ -232,8 +224,14 @@ class BaiduVector(BaseVector):
|
||||
return self._client.database(self._client_config.database)
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
return any(table.table_name == self._collection_name for table in tables)
|
||||
try:
|
||||
table = self._db.table(self._collection_name)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.TABLE_NOT_EXIST:
|
||||
return False
|
||||
else:
|
||||
raise
|
||||
return True
|
||||
|
||||
def _create_table(self, dimension: int):
|
||||
# Try to grab distributed lock and create table
|
||||
@ -287,6 +285,11 @@ class BaiduVector(BaseVector):
|
||||
field=VDBField.VECTOR,
|
||||
metric_type=metric_type,
|
||||
params=HNSWParams(m=16, efconstruction=200),
|
||||
auto_build=True,
|
||||
auto_build_index_policy=AutoBuildRowCountIncrement(
|
||||
row_count_increment=self._client_config.auto_build_row_count_increment,
|
||||
row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@ -335,7 +338,7 @@ class BaiduVector(BaseVector):
|
||||
)
|
||||
|
||||
# Wait for table created
|
||||
timeout = 300 # 5 minutes timeout
|
||||
timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
@ -345,6 +348,20 @@ class BaiduVector(BaseVector):
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
# rebuild vector index immediately after table created, make sure index is ready
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
self._wait_for_index_ready(table, timeout)
|
||||
|
||||
def _wait_for_index_ready(self, table, timeout: int = 3600):
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.vector_index)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory):
|
||||
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
|
||||
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
|
||||
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
|
||||
auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT,
|
||||
auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO,
|
||||
rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS,
|
||||
),
|
||||
)
|
||||
|
||||
@ -33,6 +33,7 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
password=new_cluster["password"],
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status="ACTIVE",
|
||||
status=TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
||||
@ -9,6 +9,7 @@ from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
class TidbService:
|
||||
@ -170,7 +171,7 @@ class TidbService:
|
||||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = "ACTIVE"
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
|
||||
@ -95,15 +95,11 @@ class FirecrawlApp:
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
total = crawl_status_response.get("total", 0)
|
||||
if total == 0:
|
||||
# Normalize to avoid None bypassing the zero-guard when the API returns null.
|
||||
total = crawl_status_response.get("total") or 0
|
||||
if total <= 0:
|
||||
raise Exception("Failed to check crawl status. Error: No page found")
|
||||
data = crawl_status_response.get("data", [])
|
||||
url_data_list: list[FirecrawlDocumentData] = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data = self._extract_common_fields(item)
|
||||
url_data_list.append(url_data)
|
||||
url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers)
|
||||
if url_data_list:
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
try:
|
||||
@ -120,6 +116,36 @@ class FirecrawlApp:
|
||||
self._handle_error(response, "check crawl status")
|
||||
raise RuntimeError("unreachable: _handle_error always raises")
|
||||
|
||||
def _collect_all_crawl_pages(
|
||||
self, first_page: dict[str, Any], headers: dict[str, str]
|
||||
) -> list[FirecrawlDocumentData]:
|
||||
"""Collect all crawl result pages by following pagination links.
|
||||
|
||||
Raises an exception if any paginated request fails, to avoid returning
|
||||
partial data that is inconsistent with the reported total.
|
||||
|
||||
The number of pages processed is capped at ``total`` (the
|
||||
server-reported page count) to guard against infinite loops caused by
|
||||
a misbehaving server that keeps returning a ``next`` URL.
|
||||
"""
|
||||
total: int = first_page.get("total") or 0
|
||||
url_data_list: list[FirecrawlDocumentData] = []
|
||||
current_page = first_page
|
||||
pages_processed = 0
|
||||
while True:
|
||||
for item in current_page.get("data", []):
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data_list.append(self._extract_common_fields(item))
|
||||
next_url: str | None = current_page.get("next")
|
||||
pages_processed += 1
|
||||
if not next_url or pages_processed >= total:
|
||||
break
|
||||
response = self._get_request(next_url, headers)
|
||||
if response.status_code != 200:
|
||||
self._handle_error(response, "fetch next crawl page")
|
||||
current_page = response.json()
|
||||
return url_data_list
|
||||
|
||||
def _format_crawl_status_response(
|
||||
self,
|
||||
status: str,
|
||||
|
||||
@ -50,7 +50,7 @@ class BuiltinTool(Tool):
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
@ -38,7 +38,7 @@ class ToolLabelManager:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
tool_type=controller.provider_type,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
@ -58,7 +58,7 @@ class ToolLabelManager:
|
||||
raise ValueError("Unsupported tool type")
|
||||
stmt = select(ToolLabelBinding.label_name).where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
ToolLabelBinding.tool_type == controller.provider_type,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from decimal import Decimal
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
@ -78,7 +79,7 @@ class ModelInvocationUtils:
|
||||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
||||
@ -1,16 +1,19 @@
|
||||
import logging
|
||||
import sys
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import NotRequired
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JsonObject = dict[str, object]
|
||||
JsonObjectList = list[JsonObject]
|
||||
|
||||
@ -30,8 +33,8 @@ class GitHubEmailRecord(TypedDict, total=False):
|
||||
class GitHubRawUserInfo(TypedDict):
|
||||
id: int | str
|
||||
login: str
|
||||
name: NotRequired[str]
|
||||
email: NotRequired[str]
|
||||
name: NotRequired[str | None]
|
||||
email: NotRequired[str | None]
|
||||
|
||||
|
||||
class GoogleRawUserInfo(TypedDict):
|
||||
@ -127,9 +130,14 @@ class GitHubOAuth(OAuth):
|
||||
response.raise_for_status()
|
||||
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
try:
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_response.raise_for_status()
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
except (httpx.HTTPStatusError, ValidationError):
|
||||
logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True)
|
||||
primary_email = None
|
||||
|
||||
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
|
||||
|
||||
@ -137,8 +145,11 @@ class GitHubOAuth(OAuth):
|
||||
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
email = payload.get("email")
|
||||
if not email:
|
||||
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||
raise ValueError(
|
||||
'Dify currently not supports the "Keep my email addresses private" feature,'
|
||||
" please disable it and login again"
|
||||
)
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
|
||||
@ -43,7 +43,9 @@ from .enums import (
|
||||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
SummaryStatus,
|
||||
TidbAuthBindingStatus,
|
||||
)
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
@ -998,7 +1000,9 @@ class ChildChunk(Base):
|
||||
# indexing fields
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
type: Mapped[SegmentType] = mapped_column(
|
||||
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
@ -1239,7 +1243,9 @@ class TidbAuthBinding(TypeBase):
|
||||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
|
||||
status: Mapped[TidbAuthBindingStatus] = mapped_column(
|
||||
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
|
||||
)
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum):
|
||||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentType(StrEnum):
|
||||
"""Document segment type"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOMIZED = "customized"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
@ -323,3 +330,10 @@ class ProviderQuotaType(StrEnum):
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ApiTokenType(StrEnum):
|
||||
"""API Token type"""
|
||||
|
||||
APP = "app"
|
||||
DATASET = "dataset"
|
||||
|
||||
@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent):
|
||||
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
|
||||
return cls(form_id=form_id, message_id=message_id)
|
||||
def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent":
|
||||
return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id)
|
||||
|
||||
form: Mapped["HumanInputForm"] = relationship(
|
||||
"HumanInputForm",
|
||||
|
||||
@ -21,7 +21,7 @@ from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
@ -31,6 +31,7 @@ from .account import Account, Tenant
|
||||
from .base import Base, TypeBase, gen_uuidv4_string
|
||||
from .engine import db
|
||||
from .enums import (
|
||||
ApiTokenType,
|
||||
AppMCPServerStatus,
|
||||
AppStatus,
|
||||
BannerStatus,
|
||||
@ -43,6 +44,8 @@ from .enums import (
|
||||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
ProviderQuotaType,
|
||||
TagType,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
@ -1796,7 +1799,7 @@ class MessageFile(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
|
||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(
|
||||
EnumText(FileTransferMethod, length=255), nullable=False
|
||||
)
|
||||
@ -2108,7 +2111,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=True)
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_used_at = mapped_column(sa.DateTime, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@ -2418,7 +2421,7 @@ class Tag(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@ -2587,7 +2590,9 @@ class TenantCreditPool(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
|
||||
pool_type: Mapped[ProviderQuotaType] = mapped_column(
|
||||
EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial"
|
||||
)
|
||||
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderSchemaType,
|
||||
ToolProviderType,
|
||||
WorkflowToolParameterConfiguration,
|
||||
)
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .model import Account, App, Tenant
|
||||
from .types import LongText, StringUUID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase):
|
||||
# tool id
|
||||
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# tool type
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||
# label name
|
||||
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
|
||||
@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase):
|
||||
# provider
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# type
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||
# tool name
|
||||
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
# invoke parameters
|
||||
|
||||
@ -1260,7 +1260,9 @@ class WorkflowAppLog(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column(
|
||||
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False
|
||||
)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@ -1340,10 +1342,14 @@ class WorkflowArchiveLog(TypeBase):
|
||||
|
||||
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column(
|
||||
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True
|
||||
)
|
||||
|
||||
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[WorkflowExecutionStatus] = mapped_column(
|
||||
EnumText(WorkflowExecutionStatus, length=255), nullable=False
|
||||
)
|
||||
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
|
||||
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
|
||||
)
|
||||
|
||||
@ -8,7 +8,7 @@ dependencies = [
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.3",
|
||||
"beautifulsoup4==4.14.3",
|
||||
"boto3==1.42.68",
|
||||
"boto3==1.42.73",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.6.2",
|
||||
@ -23,7 +23,7 @@ dependencies = [
|
||||
"gevent~=25.9.1",
|
||||
"gmpy2~=2.3.0",
|
||||
"google-api-core>=2.19.1",
|
||||
"google-api-python-client==2.192.0",
|
||||
"google-api-python-client==2.193.0",
|
||||
"google-auth>=2.47.0",
|
||||
"google-auth-httplib2==0.3.0",
|
||||
"google-cloud-aiplatform>=1.123.0",
|
||||
@ -40,7 +40,7 @@ dependencies = [
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.2", # Pinned to avoid madoka dependency issue
|
||||
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
@ -72,13 +72,14 @@ dependencies = [
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.3.0",
|
||||
"resend~=2.23.0",
|
||||
"sentry-sdk[flask]~=2.54.0",
|
||||
"resend~=2.26.0",
|
||||
"sentry-sdk[flask]~=2.55.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.52.1",
|
||||
"starlette==1.0.0",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
|
||||
"pypandoc~=1.13",
|
||||
"yarl~=1.23.0",
|
||||
"webvtt-py~=0.5.1",
|
||||
"sseclient-py~=1.9.0",
|
||||
@ -91,7 +92,7 @@ dependencies = [
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
"bleach~=6.2.0",
|
||||
"bleach~=6.3.0",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
@ -118,7 +119,7 @@ dev = [
|
||||
"ruff~=0.15.5",
|
||||
"pytest~=9.0.2",
|
||||
"pytest-benchmark~=5.2.3",
|
||||
"pytest-cov~=7.0.0",
|
||||
"pytest-cov~=7.1.0",
|
||||
"pytest-env~=1.6.0",
|
||||
"pytest-mock~=3.15.1",
|
||||
"testcontainers~=4.14.1",
|
||||
@ -202,7 +203,7 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
|
||||
# Required by vector store clients
|
||||
############################################################
|
||||
vdb = [
|
||||
"alibabacloud_gpdb20160503~=3.8.0",
|
||||
"alibabacloud_gpdb20160503~=5.1.0",
|
||||
"alibabacloud_tea_openapi~=0.4.3",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.14.1",
|
||||
|
||||
@ -8,6 +8,7 @@ from configs import dify_config
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
@ -57,7 +58,7 @@ def create_clusters(batch_size):
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
active=False,
|
||||
status="CREATING",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
db.session.add(tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
||||
@ -9,6 +9,7 @@ from configs import dify_config
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
|
||||
try:
|
||||
# check the number of idle tidb serverless
|
||||
tidb_serverless_list = db.session.scalars(
|
||||
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
||||
select(TidbAuthBinding).where(
|
||||
TidbAuthBinding.active == False,
|
||||
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
).all()
|
||||
if len(tidb_serverless_list) == 0:
|
||||
return
|
||||
|
||||
@ -7,6 +7,7 @@ from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -16,7 +17,10 @@ class CreditPoolService:
|
||||
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""create default credit pool for new tenant"""
|
||||
credit_pool = TenantCreditPool(
|
||||
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
|
||||
tenant_id=tenant_id,
|
||||
quota_limit=dify_config.HOSTED_POOL_CREDITS,
|
||||
quota_used=0,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
db.session.add(credit_pool)
|
||||
db.session.commit()
|
||||
|
||||
@ -58,6 +58,7 @@ from models.enums import (
|
||||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
@ -3786,7 +3787,7 @@ class SegmentService:
|
||||
child_chunk.word_count = len(child_chunk.content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = naive_utc_now()
|
||||
child_chunk.type = "customized"
|
||||
child_chunk.type = SegmentType.CUSTOMIZED
|
||||
update_child_chunks.append(child_chunk)
|
||||
else:
|
||||
new_child_chunks_args.append(child_chunk_update_args)
|
||||
@ -3845,7 +3846,7 @@ class SegmentService:
|
||||
child_chunk.word_count = len(content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = naive_utc_now()
|
||||
child_chunk.type = "customized"
|
||||
child_chunk.type = SegmentType.CUSTOMIZED
|
||||
db.session.add(child_chunk)
|
||||
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
|
||||
db.session.commit()
|
||||
|
||||
@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import TagType
|
||||
from models.model import App, Tag, TagBinding
|
||||
|
||||
|
||||
@ -83,7 +84,7 @@ class TagService:
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = Tag(
|
||||
name=args["name"],
|
||||
type=args["type"],
|
||||
type=TagType(args["type"]),
|
||||
created_by=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
|
||||
@ -179,7 +179,7 @@ def _record_trigger_failure_log(
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
@ -13,6 +13,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken
|
||||
from services.api_token_service import ApiTokenCache, CachedApiToken
|
||||
|
||||
@ -279,7 +280,7 @@ class TestEndToEndCacheFlow:
|
||||
test_token = ApiToken()
|
||||
test_token.id = "test-e2e-id"
|
||||
test_token.token = test_token_value
|
||||
test_token.type = test_scope
|
||||
test_token.type = ApiTokenType.APP
|
||||
test_token.app_id = "test-app"
|
||||
test_token.tenant_id = "test-tenant"
|
||||
test_token.last_used_at = None
|
||||
|
||||
@ -0,0 +1,342 @@
|
||||
"""Authenticated controller integration tests for console message APIs."""
|
||||
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload
|
||||
from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import ConversationFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
create_console_app,
|
||||
)
|
||||
|
||||
|
||||
def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation:
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=mode,
|
||||
name="Test Conversation",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
db_session.add(conversation)
|
||||
db_session.commit()
|
||||
return conversation
|
||||
|
||||
|
||||
def _create_message(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
account_id: str,
|
||||
*,
|
||||
created_at_offset_seconds: int = 0,
|
||||
) -> Message:
|
||||
created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds)
|
||||
message = Message(
|
||||
app_id=app_id,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
conversation_id=conversation_id,
|
||||
inputs={},
|
||||
query="Hello",
|
||||
message={"type": "text", "content": "Hello"},
|
||||
message_tokens=1,
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
answer="Hi there",
|
||||
answer_tokens=1,
|
||||
answer_unit_price=Decimal("0.0001"),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
parent_message_id=None,
|
||||
provider_response_latency=0,
|
||||
total_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
app_mode=AppMode.CHAT,
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
return message
|
||||
|
||||
|
||||
class TestMessageValidators:
|
||||
def test_chat_messages_query_validators(self) -> None:
|
||||
assert ChatMessagesQuery.empty_to_none("") is None
|
||||
assert ChatMessagesQuery.empty_to_none("val") == "val"
|
||||
assert ChatMessagesQuery.validate_uuid(None) is None
|
||||
assert (
|
||||
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_message_feedback_validators(self) -> None:
|
||||
assert (
|
||||
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_feedback_export_validators(self) -> None:
|
||||
assert FeedbackExportQuery.parse_bool(None) is None
|
||||
assert FeedbackExportQuery.parse_bool(True) is True
|
||||
assert FeedbackExportQuery.parse_bool("1") is True
|
||||
assert FeedbackExportQuery.parse_bool("0") is False
|
||||
assert FeedbackExportQuery.parse_bool("off") is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackExportQuery.parse_bool("invalid")
|
||||
|
||||
|
||||
def test_chat_message_list_not_found(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages",
|
||||
query_string={"conversation_id": str(uuid4())},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "not_found"
|
||||
|
||||
|
||||
def test_chat_message_list_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0)
|
||||
second = _create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
account.id,
|
||||
created_at_offset_seconds=1,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.attach_message_extra_contents",
|
||||
side_effect=_attach_message_extra_contents,
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages",
|
||||
query_string={"conversation_id": conversation.id, "limit": 1},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["limit"] == 1
|
||||
assert payload["has_more"] is True
|
||||
assert len(payload["data"]) == 1
|
||||
assert payload["data"][0]["id"] == second.id
|
||||
|
||||
|
||||
def test_message_feedback_not_found(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
f"/console/api/apps/{app.id}/feedbacks",
|
||||
json={"message_id": str(uuid4()), "rating": "like"},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "not_found"
|
||||
|
||||
|
||||
def test_message_feedback_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
|
||||
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
f"/console/api/apps/{app.id}/feedbacks",
|
||||
json={"message_id": message.id, "rating": "like"},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
|
||||
feedback = db_session_with_containers.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == message.id)
|
||||
)
|
||||
assert feedback is not None
|
||||
assert feedback.rating == FeedbackRating.LIKE
|
||||
assert feedback.from_account_id == account.id
|
||||
|
||||
|
||||
def test_message_annotation_count(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
|
||||
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
|
||||
db_session_with_containers.add(
|
||||
MessageAnnotation(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
question="Q",
|
||||
content="A",
|
||||
account_id=account.id,
|
||||
)
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/annotations/count",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"count": 1}
|
||||
|
||||
|
||||
def test_message_suggested_questions_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
message_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.MessageService.get_suggested_questions_after_answer",
|
||||
return_value=["q1", "q2"],
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"data": ["q1", "q2"]}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected_status", "expected_code"),
|
||||
[
|
||||
(MessageNotExistsError(), 404, "not_found"),
|
||||
(ConversationNotExistsError(), 404, "not_found"),
|
||||
(ProviderTokenNotInitError(), 400, "provider_not_initialize"),
|
||||
(QuotaExceededError(), 400, "provider_quota_exceeded"),
|
||||
(ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"),
|
||||
(SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"),
|
||||
(Exception(), 500, "internal_server_error"),
|
||||
],
|
||||
)
|
||||
def test_message_suggested_questions_errors(
|
||||
exc: Exception,
|
||||
expected_status: int,
|
||||
expected_code: str,
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
message_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.MessageService.get_suggested_questions_after_answer",
|
||||
side_effect=exc,
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == expected_status
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == expected_code
|
||||
|
||||
|
||||
def test_message_feedback_export_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/feedbacks/export",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"exported": True}
|
||||
|
||||
|
||||
def test_message_api_get_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
|
||||
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.attach_message_extra_contents",
|
||||
side_effect=_attach_message_extra_contents,
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/messages/{message.id}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["id"] == message.id
|
||||
@ -0,0 +1,334 @@
|
||||
"""Controller integration tests for console statistic routes."""
|
||||
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageFeedback
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
create_console_app,
|
||||
)
|
||||
|
||||
|
||||
def _create_conversation(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
account_id: str,
|
||||
*,
|
||||
mode: AppMode,
|
||||
created_at_offset_days: int = 0,
|
||||
) -> Conversation:
|
||||
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=mode,
|
||||
name="Stats Conversation",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
)
|
||||
db_session.add(conversation)
|
||||
db_session.commit()
|
||||
return conversation
|
||||
|
||||
|
||||
def _create_message(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
*,
|
||||
from_account_id: str | None,
|
||||
from_end_user_id: str | None = None,
|
||||
message_tokens: int = 1,
|
||||
answer_tokens: int = 1,
|
||||
total_price: Decimal = Decimal("0.01"),
|
||||
provider_response_latency: float = 1.0,
|
||||
created_at_offset_days: int = 0,
|
||||
) -> Message:
|
||||
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
|
||||
message = Message(
|
||||
app_id=app_id,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
conversation_id=conversation_id,
|
||||
inputs={},
|
||||
query="Hello",
|
||||
message={"type": "text", "content": "Hello"},
|
||||
message_tokens=message_tokens,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
answer="Hi there",
|
||||
answer_tokens=answer_tokens,
|
||||
answer_unit_price=Decimal("0.001"),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
parent_message_id=None,
|
||||
provider_response_latency=provider_response_latency,
|
||||
total_price=total_price,
|
||||
currency="USD",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=from_end_user_id,
|
||||
from_account_id=from_account_id,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
app_mode=AppMode.CHAT,
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
return message
|
||||
|
||||
|
||||
def _create_like_feedback(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
account_id: str,
|
||||
) -> None:
|
||||
db_session.add(
|
||||
MessageFeedback(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def test_daily_message_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-messages",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["message_count"] == 1
|
||||
|
||||
|
||||
def test_daily_conversation_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-conversations",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["conversation_count"] == 1
|
||||
|
||||
|
||||
def test_daily_terminals_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=None,
|
||||
from_end_user_id=str(uuid4()),
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-end-users",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["terminal_count"] == 1
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=account.id,
|
||||
message_tokens=40,
|
||||
answer_tokens=60,
|
||||
total_price=Decimal("0.02"),
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/token-costs",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["data"][0]["token_count"] == 100
|
||||
assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02")
|
||||
|
||||
|
||||
def test_average_session_interaction_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/average-session-interactions",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["interactions"] == 2.0
|
||||
|
||||
|
||||
def test_user_satisfaction_rate_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
for _ in range(9):
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["rate"] == 100.0
|
||||
|
||||
|
||||
def test_average_response_time_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=account.id,
|
||||
provider_response_latency=1.234,
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/average-response-time",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["latency"] == 1234.0
|
||||
|
||||
|
||||
def test_tokens_per_second_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=account.id,
|
||||
answer_tokens=31,
|
||||
provider_response_latency=2.0,
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/tokens-per-second",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["tps"] == 15.5
|
||||
|
||||
|
||||
def test_invalid_time_range(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "Invalid time"
|
||||
|
||||
|
||||
def test_time_range_params_passed(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
import datetime
|
||||
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
start = datetime.datetime.now()
|
||||
end = datetime.datetime.now()
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse:
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_parse.assert_called_once_with("something", "something", "UTC")
|
||||
@ -0,0 +1,415 @@
|
||||
"""Authenticated controller integration tests for workflow draft variable APIs."""
|
||||
|
||||
import uuid
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from factories.variable_factory import segment_to_variable
|
||||
from models import Workflow
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
create_console_app,
|
||||
)
|
||||
|
||||
|
||||
def _create_draft_workflow(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
*,
|
||||
environment_variables: list | None = None,
|
||||
conversation_variables: list | None = None,
|
||||
) -> Workflow:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type="workflow",
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features="{}",
|
||||
created_by=account_id,
|
||||
environment_variables=environment_variables or [],
|
||||
conversation_variables=conversation_variables or [],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db_session.add(workflow)
|
||||
db_session.commit()
|
||||
return workflow
|
||||
|
||||
|
||||
def _create_node_variable(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
node_id: str = "node_1",
|
||||
name: str = "test_var",
|
||||
) -> WorkflowDraftVariable:
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
db_session.add(variable)
|
||||
db_session.commit()
|
||||
return variable
|
||||
|
||||
|
||||
def _create_system_variable(
|
||||
db_session: Session, app_id: str, user_id: str, name: str = "query"
|
||||
) -> WorkflowDraftVariable:
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
value=StringSegment(value="system-value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
editable=True,
|
||||
)
|
||||
db_session.add(variable)
|
||||
db_session.commit()
|
||||
return variable
|
||||
|
||||
|
||||
def _build_environment_variable(name: str, value: str):
|
||||
return segment_to_variable(
|
||||
segment=StringSegment(value=value),
|
||||
selector=[ENVIRONMENT_VARIABLE_NODE_ID, name],
|
||||
name=name,
|
||||
description=f"Environment variable {name}",
|
||||
)
|
||||
|
||||
|
||||
def _build_conversation_variable(name: str, value: str):
|
||||
return segment_to_variable(
|
||||
segment=StringSegment(value=value),
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, name],
|
||||
name=name,
|
||||
description=f"Conversation variable {name}",
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_variable_collection_get_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_workflow_variable_collection_get_not_exist(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/variables",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "draft_workflow_not_exist"
|
||||
|
||||
|
||||
def test_workflow_variable_collection_delete(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_node_variable(db_session_with_containers, app.id, account.id)
|
||||
_create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_2", name="other_var")
|
||||
|
||||
response = test_client_with_containers.delete(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/variables",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
remaining = db_session_with_containers.scalars(
|
||||
select(WorkflowDraftVariable).where(
|
||||
WorkflowDraftVariable.app_id == app.id,
|
||||
WorkflowDraftVariable.user_id == account.id,
|
||||
)
|
||||
).all()
|
||||
assert remaining == []
|
||||
|
||||
|
||||
def test_node_variable_collection_get_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123")
|
||||
_create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other")
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert [item["id"] for item in payload["items"]] == [node_variable.id]
|
||||
|
||||
|
||||
def test_node_variable_collection_get_invalid_node_id(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "invalid_param"
|
||||
|
||||
|
||||
def test_node_variable_collection_delete(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
target = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123")
|
||||
untouched = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456")
|
||||
target_id = target.id
|
||||
untouched_id = untouched.id
|
||||
|
||||
response = test_client_with_containers.delete(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert (
|
||||
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == target_id))
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == untouched_id))
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def test_variable_api_get_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
|
||||
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["id"] == variable.id
|
||||
assert payload["name"] == "test_var"
|
||||
|
||||
|
||||
def test_variable_api_get_not_found(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "not_found"
|
||||
|
||||
|
||||
def test_variable_api_patch_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
|
||||
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
|
||||
|
||||
response = test_client_with_containers.patch(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
json={"name": "renamed_var"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["id"] == variable.id
|
||||
assert payload["name"] == "renamed_var"
|
||||
|
||||
refreshed = db_session_with_containers.scalar(
|
||||
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)
|
||||
)
|
||||
assert refreshed is not None
|
||||
assert refreshed.name == "renamed_var"
|
||||
|
||||
|
||||
def test_variable_api_delete_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
|
||||
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
|
||||
|
||||
response = test_client_with_containers.delete(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert (
|
||||
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id))
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_variable_reset_api_put_success_returns_no_content_without_execution(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
|
||||
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
|
||||
|
||||
response = test_client_with_containers.put(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert (
|
||||
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id))
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_conversation_variable_collection_get(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_draft_workflow(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
tenant.id,
|
||||
account.id,
|
||||
conversation_variables=[_build_conversation_variable("session_name", "Alice")],
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/conversation-variables",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert [item["name"] for item in payload["items"]] == ["session_name"]
|
||||
|
||||
created = db_session_with_containers.scalars(
|
||||
select(WorkflowDraftVariable).where(
|
||||
WorkflowDraftVariable.app_id == app.id,
|
||||
WorkflowDraftVariable.user_id == account.id,
|
||||
WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID,
|
||||
)
|
||||
).all()
|
||||
assert len(created) == 1
|
||||
|
||||
|
||||
def test_system_variable_collection_get(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
variable = _create_system_variable(db_session_with_containers, app.id, account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/system-variables",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert [item["id"] for item in payload["items"]] == [variable.id]
|
||||
|
||||
|
||||
def test_environment_variable_collection_get(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
|
||||
_create_draft_workflow(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
tenant.id,
|
||||
account.id,
|
||||
environment_variables=[_build_environment_variable("api_key", "secret-value")],
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/workflows/draft/environment-variables",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["items"][0]["name"] == "api_key"
|
||||
assert payload["items"][0]["value"] == "secret-value"
|
||||
@ -0,0 +1,131 @@
|
||||
"""Controller integration tests for API key data source auth routes."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def test_get_api_key_auth_data_source(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant.id,
|
||||
category="api_key",
|
||||
provider="custom_provider",
|
||||
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
|
||||
disabled=False,
|
||||
)
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/api-key-auth/data-source",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert len(payload["sources"]) == 1
|
||||
assert payload["sources"][0]["provider"] == "custom_provider"
|
||||
|
||||
|
||||
def test_get_api_key_auth_data_source_empty(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/api-key-auth/data-source",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"sources": []}
|
||||
|
||||
|
||||
def test_create_binding_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
|
||||
with (
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
|
||||
|
||||
def test_create_binding_failure(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
|
||||
with (
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth",
|
||||
side_effect=ValueError("Invalid structure"),
|
||||
),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "auth_failed"
|
||||
assert payload["message"] == "Invalid structure"
|
||||
|
||||
|
||||
def test_delete_binding_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant.id,
|
||||
category="api_key",
|
||||
provider="custom_provider",
|
||||
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
|
||||
disabled=False,
|
||||
)
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
response = test_client_with_containers.delete(
|
||||
f"/console/api/api-key-auth/data-source/{binding.id}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert (
|
||||
db_session_with_containers.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id)
|
||||
)
|
||||
is None
|
||||
)
|
||||
@ -0,0 +1,120 @@
|
||||
"""Controller integration tests for console OAuth data source routes."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.source import DataSourceOauthBinding
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth_url_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
provider = MagicMock()
|
||||
provider.get_authorization_url.return_value = "http://oauth.provider/auth"
|
||||
|
||||
with (
|
||||
patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}),
|
||||
patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None),
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/oauth/data-source/notion",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert tenant.id == account.current_tenant_id
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"data": "http://oauth.provider/auth"}
|
||||
provider.get_authorization_url.assert_called_once()
|
||||
|
||||
|
||||
def test_get_oauth_url_invalid_provider(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
|
||||
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/oauth/data-source/unknown_provider",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json() == {"error": "Invalid provider"}
|
||||
|
||||
|
||||
def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None:
|
||||
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
|
||||
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion?code=mock_code")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "code=mock_code" in response.location
|
||||
|
||||
|
||||
def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None:
|
||||
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
|
||||
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "error=Access%20denied" in response.location
|
||||
|
||||
|
||||
def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None:
|
||||
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
|
||||
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/invalid?code=mock_code")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json() == {"error": "Invalid provider"}
|
||||
|
||||
|
||||
def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None:
|
||||
provider = MagicMock()
|
||||
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}):
|
||||
response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=auth_code_123")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
provider.get_access_token.assert_called_once_with("auth_code_123")
|
||||
|
||||
|
||||
def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None:
|
||||
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
|
||||
response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json() == {"error": "Invalid code"}
|
||||
|
||||
|
||||
def test_sync_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
binding = DataSourceOauthBinding(
|
||||
tenant_id=tenant.id,
|
||||
access_token="test-access-token",
|
||||
provider="notion",
|
||||
source_info={"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []},
|
||||
disabled=False,
|
||||
)
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
provider = MagicMock()
|
||||
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/oauth/data-source/notion/{binding.id}/sync",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
provider.sync_data_source.assert_called_once_with(binding.id)
|
||||
@ -0,0 +1,365 @@
|
||||
"""Controller integration tests for console OAuth server routes."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
ensure_dify_setup,
|
||||
)
|
||||
|
||||
|
||||
def _build_oauth_provider_app() -> OAuthProviderApp:
|
||||
return OAuthProviderApp(
|
||||
app_icon="icon_url",
|
||||
client_id="test_client_id",
|
||||
client_secret="test_secret",
|
||||
app_label={"en-US": "Test App"},
|
||||
redirect_uris=["http://localhost/callback"],
|
||||
scope="read,write",
|
||||
)
|
||||
|
||||
|
||||
def test_oauth_provider_successful_post(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["app_icon"] == "icon_url"
|
||||
assert payload["app_label"] == {"en-US": "Test App"}
|
||||
assert payload["scope"] == "read,write"
|
||||
|
||||
|
||||
def test_oauth_provider_invalid_redirect_uri(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert "redirect_uri is invalid" in payload["message"]
|
||||
|
||||
|
||||
def test_oauth_provider_invalid_client_id(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider",
|
||||
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert "client_id is invalid" in payload["message"]
|
||||
|
||||
|
||||
def test_oauth_authorize_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code",
|
||||
return_value="auth_code_123",
|
||||
) as mock_sign,
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/authorize",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"code": "auth_code_123"}
|
||||
mock_sign.assert_called_once_with("test_client_id", account.id)
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token",
|
||||
return_value=("access_123", "refresh_123"),
|
||||
),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"access_token": "access_123",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": "refresh_123",
|
||||
}
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_missing_code(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "code is required"
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_invalid_secret(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "invalid_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "client_secret is invalid"
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_invalid_redirect_uri(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://invalid/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "redirect_uri is invalid"
|
||||
|
||||
|
||||
def test_oauth_token_refresh_token_grant(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token",
|
||||
return_value=("new_access", "new_refresh"),
|
||||
),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"access_token": "new_access",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": "new_refresh",
|
||||
}
|
||||
|
||||
|
||||
def test_oauth_token_refresh_token_grant_missing_token(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={"client_id": "test_client_id", "grant_type": "refresh_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "refresh_token is required"
|
||||
|
||||
|
||||
def test_oauth_token_invalid_grant_type(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={"client_id": "test_client_id", "grant_type": "invalid_grant"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "invalid grant_type"
|
||||
|
||||
|
||||
def test_oauth_account_successful_retrieval(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
account.avatar = "avatar_url"
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token",
|
||||
return_value=account,
|
||||
),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/account",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer valid_access_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"name": "Test User",
|
||||
"email": account.email,
|
||||
"avatar": "avatar_url",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
|
||||
|
||||
def test_oauth_account_missing_authorization_header(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/account",
|
||||
json={"client_id": "test_client_id"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.get_json() == {"error": "Authorization header is required"}
|
||||
|
||||
|
||||
def test_oauth_account_invalid_authorization_header_format(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
|
||||
return_value=_build_oauth_provider_app(),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/account",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "InvalidFormat"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.get_json() == {"error": "Invalid Authorization header format"}
|
||||
@ -1,17 +1,10 @@
|
||||
"""
|
||||
Test suite for password reset authentication flows.
|
||||
"""Testcontainers integration tests for password reset authentication flows."""
|
||||
|
||||
This module tests the password reset mechanism including:
|
||||
- Password reset email sending
|
||||
- Verification code validation
|
||||
- Password reset with token
|
||||
- Rate limiting and security checks
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import (
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_session():
|
||||
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_session_cls.return_value.__exit__.return_value = None
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_db():
|
||||
with patch("controllers.console.auth.forgot_password.db") as mock_db:
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi:
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test successful password reset email sending.
|
||||
|
||||
Verifies that:
|
||||
- Email is sent to valid account
|
||||
- Reset token is generated and returned
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "reset_token_123"
|
||||
@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi:
|
||||
assert response["data"] == "reset_token_123"
|
||||
mock_send_email.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
|
||||
"""
|
||||
Test password reset email blocked by IP rate limit.
|
||||
|
||||
@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi:
|
||||
- No email is sent when rate limited
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi:
|
||||
(None, "en-US"), # Defaults to en-US when not provided
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
language_input,
|
||||
@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi:
|
||||
- Unsupported languages default to en-US
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi:
|
||||
"""Test cases for verifying password reset codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi:
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""
|
||||
@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi:
|
||||
- Rate limit is reset on success
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_generate_token.return_value = (None, "new_token")
|
||||
@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi:
|
||||
)
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi:
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi:
|
||||
mock_revoke_token.assert_called_once_with("upper_token")
|
||||
mock_reset_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification blocked by rate limit.
|
||||
|
||||
@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi:
|
||||
- Prevents brute force attacks on verification codes
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi:
|
||||
with pytest.raises(EmailPasswordResetLimitError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with invalid token.
|
||||
|
||||
@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi:
|
||||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = None
|
||||
|
||||
@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi:
|
||||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with mismatched email.
|
||||
|
||||
@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi:
|
||||
- Prevents token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi:
|
||||
with pytest.raises(InvalidEmailError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with incorrect code.
|
||||
|
||||
@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi:
|
||||
- Rate limit counter is incremented
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
@ -380,11 +321,8 @@ class TestForgotPasswordResetApi:
|
||||
"""Test cases for resetting password with verified token."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
@ -394,7 +332,6 @@ class TestForgotPasswordResetApi:
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@ -405,7 +342,6 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
@ -418,7 +354,6 @@ class TestForgotPasswordResetApi:
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
@ -436,9 +371,8 @@ class TestForgotPasswordResetApi:
|
||||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with("valid_token")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_mismatch(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with mismatched passwords.
|
||||
|
||||
@ -447,7 +381,6 @@ class TestForgotPasswordResetApi:
|
||||
- No password update occurs
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
|
||||
# Act & Assert
|
||||
@ -460,9 +393,8 @@ class TestForgotPasswordResetApi:
|
||||
with pytest.raises(PasswordMismatchError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_invalid_token(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with invalid token.
|
||||
|
||||
@ -470,7 +402,6 @@ class TestForgotPasswordResetApi:
|
||||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@ -483,9 +414,8 @@ class TestForgotPasswordResetApi:
|
||||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with token not in reset phase.
|
||||
|
||||
@ -494,7 +424,6 @@ class TestForgotPasswordResetApi:
|
||||
- Prevents use of verification-phase tokens for reset
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
|
||||
|
||||
# Act & Assert
|
||||
@ -507,13 +436,10 @@ class TestForgotPasswordResetApi:
|
||||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_reset_password_account_not_found(
|
||||
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
|
||||
):
|
||||
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
|
||||
@ -521,7 +447,6 @@ class TestForgotPasswordResetApi:
|
||||
- AccountNotFound is raised when account doesn't exist
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = None
|
||||
|
||||
@ -0,0 +1,85 @@
|
||||
"""Shared helpers for authenticated console controller integration tests."""
|
||||
|
||||
import uuid
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HEADER_NAME_CSRF_TOKEN
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.token import _real_cookie_name, generate_csrf_token
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.model import App, AppMode
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
def ensure_dify_setup(db_session: Session) -> None:
|
||||
"""Create a setup marker once so setup-protected console routes can be exercised."""
|
||||
if db_session.scalar(select(DifySetup).limit(1)) is not None:
|
||||
return
|
||||
|
||||
db_session.add(DifySetup(version=dify_config.project.version))
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]:
|
||||
"""Create an initialized owner account with a current tenant."""
|
||||
account = Account(
|
||||
email=f"test-{uuid.uuid4()}@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.initialized_at = naive_utc_now()
|
||||
db_session.add(account)
|
||||
db_session.commit()
|
||||
|
||||
tenant = Tenant(name="Test Tenant", status="normal")
|
||||
db_session.add(tenant)
|
||||
db_session.commit()
|
||||
|
||||
db_session.add(
|
||||
TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
account.set_tenant_id(tenant.id)
|
||||
account.timezone = "UTC"
|
||||
db_session.commit()
|
||||
|
||||
ensure_dify_setup(db_session)
|
||||
return account, tenant
|
||||
|
||||
|
||||
def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App:
|
||||
"""Create a minimal app row that can be loaded by get_app_model."""
|
||||
app = App(
|
||||
tenant_id=tenant_id,
|
||||
name="Test App",
|
||||
mode=mode,
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
created_by=account_id,
|
||||
)
|
||||
db_session.add(app)
|
||||
db_session.commit()
|
||||
return app
|
||||
|
||||
|
||||
def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]:
|
||||
"""Attach console auth cookies/headers for endpoints guarded by login_required."""
|
||||
access_token = AccountService.get_account_jwt_token(account)
|
||||
csrf_token = generate_csrf_token(account.id)
|
||||
test_client.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost")
|
||||
return {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
HEADER_NAME_CSRF_TOKEN: csrf_token,
|
||||
}
|
||||
@ -1,27 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
|
||||
create_human_input_message_fixture,
|
||||
)
|
||||
|
||||
|
||||
def test_get_by_message_ids_returns_human_input_content(db_session_with_containers):
|
||||
fixture = create_human_input_message_fixture(db_session_with_containers)
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
)
|
||||
|
||||
results = repository.get_by_message_ids([fixture.message.id])
|
||||
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) == 1
|
||||
content = results[0][0]
|
||||
assert content.submitted is True
|
||||
assert content.form_submission_data is not None
|
||||
assert content.form_submission_data.action_id == fixture.action_id
|
||||
assert content.form_submission_data.action_text == fixture.action_text
|
||||
assert content.form_submission_data.rendered_content == fixture.form.rendered_content
|
||||
@ -27,7 +27,7 @@ from models.human_input import (
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
@ -218,7 +218,7 @@ class TestDeleteRunsWithRelated:
|
||||
app_id=test_scope.app_id,
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
@ -278,7 +278,7 @@ class TestCountRunsWithRelated:
|
||||
app_id=test_scope.app_id,
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
|
||||
@ -0,0 +1,407 @@
|
||||
"""Integration tests for SQLAlchemyExecutionExtraContentRepository using Testcontainers.
|
||||
|
||||
Part of #32454 — replaces the mock-based unit tests with real database interactions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from dify_graph.nodes.human_input.entities import FormDefinition, UserAction
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormStatus
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import ConversationFromSource, InvokeFrom
|
||||
from models.execution_extra_content import ExecutionExtraContent, HumanInputContent
|
||||
from models.human_input import (
|
||||
ConsoleRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.model import App, Conversation, Message
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TestScope:
|
||||
"""Per-test data scope used to isolate DB rows.
|
||||
|
||||
IDs are populated after flushing the base entities to the database.
|
||||
"""
|
||||
|
||||
tenant_id: str = ""
|
||||
app_id: str = ""
|
||||
user_id: str = ""
|
||||
|
||||
|
||||
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
|
||||
"""Remove test-created DB rows for a test scope."""
|
||||
form_ids_subquery = select(HumanInputForm.id).where(
|
||||
HumanInputForm.tenant_id == scope.tenant_id,
|
||||
)
|
||||
session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
|
||||
session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
|
||||
session.execute(
|
||||
delete(ExecutionExtraContent).where(
|
||||
ExecutionExtraContent.workflow_run_id.in_(
|
||||
select(HumanInputForm.workflow_run_id).where(HumanInputForm.tenant_id == scope.tenant_id)
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(delete(HumanInputForm).where(HumanInputForm.tenant_id == scope.tenant_id))
|
||||
session.execute(delete(Message).where(Message.app_id == scope.app_id))
|
||||
session.execute(delete(Conversation).where(Conversation.app_id == scope.app_id))
|
||||
session.execute(delete(App).where(App.id == scope.app_id))
|
||||
session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == scope.tenant_id))
|
||||
session.execute(delete(Account).where(Account.id == scope.user_id))
|
||||
session.execute(delete(Tenant).where(Tenant.id == scope.tenant_id))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _seed_base_entities(session: Session, scope: _TestScope) -> None:
|
||||
"""Create the base tenant, account, and app needed by tests."""
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
scope.tenant_id = tenant.id
|
||||
|
||||
account = Account(
|
||||
name="Test Account",
|
||||
email=f"test_{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
session.add(account)
|
||||
session.flush()
|
||||
scope.user_id = account.id
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=scope.tenant_id,
|
||||
account_id=scope.user_id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
session.add(tenant_join)
|
||||
|
||||
app = App(
|
||||
tenant_id=scope.tenant_id,
|
||||
name="Test App",
|
||||
description="",
|
||||
mode="chat",
|
||||
icon_type="emoji",
|
||||
icon="bot",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=False,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=100,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=scope.user_id,
|
||||
updated_by=scope.user_id,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
scope.app_id = app.id
|
||||
|
||||
|
||||
def _create_conversation(session: Session, scope: _TestScope) -> Conversation:
|
||||
conversation = Conversation(
|
||||
app_id=scope.app_id,
|
||||
mode="chat",
|
||||
name="Test Conversation",
|
||||
summary="",
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=scope.user_id,
|
||||
from_end_user_id=None,
|
||||
)
|
||||
conversation.inputs = {}
|
||||
session.add(conversation)
|
||||
session.flush()
|
||||
return conversation
|
||||
|
||||
|
||||
def _create_message(
|
||||
session: Session,
|
||||
scope: _TestScope,
|
||||
conversation_id: str,
|
||||
workflow_run_id: str,
|
||||
) -> Message:
|
||||
message = Message(
|
||||
app_id=scope.app_id,
|
||||
conversation_id=conversation_id,
|
||||
inputs={},
|
||||
query="test query",
|
||||
message={"messages": []},
|
||||
answer="test answer",
|
||||
message_tokens=50,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
answer_tokens=80,
|
||||
answer_unit_price=Decimal("0.001"),
|
||||
provider_response_latency=0.5,
|
||||
currency="USD",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=scope.user_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
session.add(message)
|
||||
session.flush()
|
||||
return message
|
||||
|
||||
|
||||
def _create_submitted_form(
|
||||
session: Session,
|
||||
scope: _TestScope,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
action_id: str = "approve",
|
||||
action_title: str = "Approve",
|
||||
node_title: str = "Approval",
|
||||
) -> HumanInputForm:
|
||||
expiration_time = datetime.utcnow() + timedelta(days=1)
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id=action_id, title=action_title)],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
node_title=node_title,
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
tenant_id=scope.tenant_id,
|
||||
app_id=scope.app_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id="node-id",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content=f"Rendered {action_title}",
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=expiration_time,
|
||||
selected_action_id=action_id,
|
||||
)
|
||||
session.add(form)
|
||||
session.flush()
|
||||
return form
|
||||
|
||||
|
||||
def _create_waiting_form(
|
||||
session: Session,
|
||||
scope: _TestScope,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
default_values: dict | None = None,
|
||||
) -> HumanInputForm:
|
||||
expiration_time = datetime.utcnow() + timedelta(days=1)
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values=default_values or {"name": "John"},
|
||||
node_title="Approval",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
tenant_id=scope.tenant_id,
|
||||
app_id=scope.app_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id="node-id",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="Rendered block",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
session.add(form)
|
||||
session.flush()
|
||||
return form
|
||||
|
||||
|
||||
def _create_human_input_content(
|
||||
session: Session,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
message_id: str,
|
||||
form_id: str,
|
||||
) -> HumanInputContent:
|
||||
content = HumanInputContent.new(
|
||||
workflow_run_id=workflow_run_id,
|
||||
message_id=message_id,
|
||||
form_id=form_id,
|
||||
)
|
||||
session.add(content)
|
||||
return content
|
||||
|
||||
|
||||
def _create_recipient(
|
||||
session: Session,
|
||||
*,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
recipient_type: RecipientType = RecipientType.CONSOLE,
|
||||
access_token: str = "token-1",
|
||||
) -> HumanInputFormRecipient:
|
||||
payload = ConsoleRecipientPayload(account_id=None)
|
||||
recipient = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=recipient_type,
|
||||
recipient_payload=payload.model_dump_json(),
|
||||
access_token=access_token,
|
||||
)
|
||||
session.add(recipient)
|
||||
return recipient
|
||||
|
||||
|
||||
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
|
||||
from dify_graph.nodes.human_input.enums import DeliveryMethodType
|
||||
from models.human_input import ConsoleDeliveryPayload
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
form_id=form_id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
||||
)
|
||||
session.add(delivery)
|
||||
session.flush()
|
||||
return delivery
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(db_session_with_containers: Session) -> SQLAlchemyExecutionExtraContentRepository:
|
||||
"""Build a repository backed by the testcontainers database engine."""
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
return SQLAlchemyExecutionExtraContentRepository(sessionmaker(bind=engine, expire_on_commit=False))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_scope(db_session_with_containers: Session) -> Generator[_TestScope]:
|
||||
"""Provide an isolated scope and clean related data after each test."""
|
||||
scope = _TestScope()
|
||||
_seed_base_entities(db_session_with_containers, scope)
|
||||
db_session_with_containers.commit()
|
||||
yield scope
|
||||
_cleanup_scope_data(db_session_with_containers, scope)
|
||||
|
||||
|
||||
class TestGetByMessageIds:
|
||||
"""Tests for SQLAlchemyExecutionExtraContentRepository.get_by_message_ids."""
|
||||
|
||||
def test_groups_contents_by_message(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
repository: SQLAlchemyExecutionExtraContentRepository,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Submitted forms are correctly mapped and grouped by message ID."""
|
||||
workflow_run_id = str(uuid4())
|
||||
conversation = _create_conversation(db_session_with_containers, test_scope)
|
||||
msg1 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
|
||||
msg2 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
|
||||
|
||||
form = _create_submitted_form(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
workflow_run_id=workflow_run_id,
|
||||
action_id="approve",
|
||||
action_title="Approve",
|
||||
)
|
||||
_create_human_input_content(
|
||||
db_session_with_containers,
|
||||
workflow_run_id=workflow_run_id,
|
||||
message_id=msg1.id,
|
||||
form_id=form.id,
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = repository.get_by_message_ids([msg1.id, msg2.id])
|
||||
|
||||
assert len(result) == 2
|
||||
# msg1 has one submitted content
|
||||
assert len(result[0]) == 1
|
||||
content = result[0][0]
|
||||
assert content.submitted is True
|
||||
assert content.workflow_run_id == workflow_run_id
|
||||
assert content.form_submission_data is not None
|
||||
assert content.form_submission_data.action_id == "approve"
|
||||
assert content.form_submission_data.action_text == "Approve"
|
||||
assert content.form_submission_data.rendered_content == "Rendered Approve"
|
||||
assert content.form_submission_data.node_id == "node-id"
|
||||
assert content.form_submission_data.node_title == "Approval"
|
||||
# msg2 has no content
|
||||
assert result[1] == []
|
||||
|
||||
def test_returns_unsubmitted_form_definition(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
repository: SQLAlchemyExecutionExtraContentRepository,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Waiting forms return full form_definition with resolved token and defaults."""
|
||||
workflow_run_id = str(uuid4())
|
||||
conversation = _create_conversation(db_session_with_containers, test_scope)
|
||||
msg = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
|
||||
|
||||
form = _create_waiting_form(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
workflow_run_id=workflow_run_id,
|
||||
default_values={"name": "John"},
|
||||
)
|
||||
delivery = _create_delivery(db_session_with_containers, form_id=form.id)
|
||||
_create_recipient(
|
||||
db_session_with_containers,
|
||||
form_id=form.id,
|
||||
delivery_id=delivery.id,
|
||||
access_token="token-1",
|
||||
)
|
||||
_create_human_input_content(
|
||||
db_session_with_containers,
|
||||
workflow_run_id=workflow_run_id,
|
||||
message_id=msg.id,
|
||||
form_id=form.id,
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = repository.get_by_message_ids([msg.id])
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 1
|
||||
domain_content = result[0][0]
|
||||
assert domain_content.submitted is False
|
||||
assert domain_content.workflow_run_id == workflow_run_id
|
||||
assert domain_content.form_definition is not None
|
||||
form_def = domain_content.form_definition
|
||||
assert form_def.form_id == form.id
|
||||
assert form_def.node_id == "node-id"
|
||||
assert form_def.node_title == "Approval"
|
||||
assert form_def.form_content == "Rendered block"
|
||||
assert form_def.display_in_ui is True
|
||||
assert form_def.form_token == "token-1"
|
||||
assert form_def.resolved_default_values == {"name": "John"}
|
||||
assert form_def.expiration_time == int(form.expiration_time.timestamp())
|
||||
|
||||
def test_empty_message_ids_returns_empty_list(
|
||||
self,
|
||||
repository: SQLAlchemyExecutionExtraContentRepository,
|
||||
) -> None:
|
||||
"""Passing no message IDs returns an empty list without hitting the DB."""
|
||||
result = repository.get_by_message_ids([])
|
||||
assert result == []
|
||||
@ -525,3 +525,147 @@ class TestAPIBasedExtensionService:
|
||||
# Try to get extension with wrong tenant ID
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
|
||||
|
||||
def test_save_extension_api_key_exactly_four_chars_rejected(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""API key with exactly 4 characters should be rejected (boundary)."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key="1234",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_api_key_exactly_five_chars_accepted(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""API key with exactly 5 characters should be accepted (boundary)."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key="12345",
|
||||
)
|
||||
|
||||
saved = APIBasedExtensionService.save(extension_data)
|
||||
assert saved.id is not None
|
||||
|
||||
def test_save_extension_requestor_constructor_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Exception raised by requestor constructor is wrapped in ValueError."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
|
||||
mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config")
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="connection error: bad config"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_network_exception(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Network exceptions during ping are wrapped in ValueError."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
|
||||
mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError(
|
||||
"network failure"
|
||||
)
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="connection error: network failure"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_update_duplicate_name_rejected(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Updating an existing extension to use another extension's name should fail."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
|
||||
ext1 = APIBasedExtensionService.save(
|
||||
APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Extension Alpha",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
)
|
||||
ext2 = APIBasedExtensionService.save(
|
||||
APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Extension Beta",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
)
|
||||
|
||||
# Try to rename ext2 to ext1's name
|
||||
ext2.name = "Extension Alpha"
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService.save(ext2)
|
||||
|
||||
def test_get_all_returns_empty_for_different_tenant(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Extensions from one tenant should not be visible to another."""
|
||||
fake = Faker()
|
||||
_, tenant1 = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
_, tenant2 = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant1 is not None
|
||||
|
||||
APIBasedExtensionService.save(
|
||||
APIBasedExtension(
|
||||
tenant_id=tenant1.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
)
|
||||
|
||||
assert tenant2 is not None
|
||||
result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id)
|
||||
assert result == []
|
||||
|
||||
@ -0,0 +1,80 @@
|
||||
"""Testcontainers integration tests for AttachmentService."""
|
||||
|
||||
import base64
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.attachment_service as attachment_service_module
|
||||
from extensions.ext_database import db
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.attachment_service import AttachmentService
|
||||
|
||||
|
||||
class TestAttachmentService:
|
||||
def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id or str(uuid4()),
|
||||
storage_type=StorageType.OPENDAL,
|
||||
key=f"upload/{uuid4()}.txt",
|
||||
name="test-file.txt",
|
||||
size=100,
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
return upload_file
|
||||
|
||||
def test_should_initialize_with_sessionmaker(self):
|
||||
session_factory = sessionmaker()
|
||||
|
||||
service = AttachmentService(session_factory=session_factory)
|
||||
|
||||
assert service._session_maker is session_factory
|
||||
|
||||
def test_should_initialize_with_engine(self):
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
service = AttachmentService(session_factory=engine)
|
||||
session = service._session_maker()
|
||||
try:
|
||||
assert session.bind == engine
|
||||
finally:
|
||||
session.close()
|
||||
engine.dispose()
|
||||
|
||||
@pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
|
||||
def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory):
|
||||
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
|
||||
AttachmentService(session_factory=invalid_session_factory)
|
||||
|
||||
def test_should_return_base64_when_file_exists(self, db_session_with_containers):
|
||||
upload_file = self._create_upload_file(db_session_with_containers)
|
||||
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
|
||||
result = service.get_file_base64(upload_file.id)
|
||||
|
||||
assert result == base64.b64encode(b"binary-content").decode()
|
||||
mock_load.assert_called_once_with(upload_file.key)
|
||||
|
||||
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
|
||||
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
service.get_file_base64(str(uuid4()))
|
||||
|
||||
mock_load.assert_not_called()
|
||||
@ -0,0 +1,58 @@
|
||||
"""Testcontainers integration tests for ConversationVariableUpdater."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from dify_graph.variables import StringVariable
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import ConversationVariable
|
||||
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
|
||||
|
||||
|
||||
class TestConversationVariableUpdater:
|
||||
def _create_conversation_variable(
|
||||
self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
|
||||
) -> ConversationVariable:
|
||||
row = ConversationVariable(
|
||||
id=variable.id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id or str(uuid4()),
|
||||
data=variable.model_dump_json(),
|
||||
)
|
||||
db_session_with_containers.add(row)
|
||||
db_session_with_containers.commit()
|
||||
return row
|
||||
|
||||
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
|
||||
conversation_id = str(uuid4())
|
||||
variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
|
||||
self._create_conversation_variable(
|
||||
db_session_with_containers, conversation_id=conversation_id, variable=variable
|
||||
)
|
||||
|
||||
updated_variable = StringVariable(id=variable.id, name="topic", value="new value")
|
||||
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
|
||||
|
||||
updater.update(conversation_id=conversation_id, variable=updated_variable)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id))
|
||||
assert row is not None
|
||||
assert row.data == updated_variable.model_dump_json()
|
||||
|
||||
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
|
||||
conversation_id = str(uuid4())
|
||||
variable = StringVariable(id=str(uuid4()), name="topic", value="value")
|
||||
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
|
||||
|
||||
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
|
||||
updater.update(conversation_id=conversation_id, variable=variable)
|
||||
|
||||
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
|
||||
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
|
||||
|
||||
result = updater.flush()
|
||||
|
||||
assert result is None
|
||||
@ -0,0 +1,104 @@
|
||||
"""Testcontainers integration tests for CreditPoolService."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
class TestCreditPoolService:
|
||||
def _create_tenant_id(self) -> str:
|
||||
return str(uuid4())
|
||||
|
||||
def test_create_default_pool(self, db_session_with_containers):
|
||||
tenant_id = self._create_tenant_id()
|
||||
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
assert isinstance(pool, TenantCreditPool)
|
||||
assert pool.tenant_id == tenant_id
|
||||
assert pool.pool_type == ProviderQuotaType.TRIAL
|
||||
assert pool.quota_used == 0
|
||||
assert pool.quota_limit > 0
|
||||
|
||||
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers):
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert result is not None
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.pool_type == ProviderQuotaType.TRIAL
|
||||
|
||||
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
|
||||
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers):
|
||||
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers):
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
# Exhaust credits
|
||||
pool.quota_used = pool.quota_limit
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers):
|
||||
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
pool.quota_used = pool.quota_limit
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with pytest.raises(QuotaExceededError, match="No credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
|
||||
|
||||
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers):
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
credits_required = 10
|
||||
|
||||
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required)
|
||||
|
||||
assert result == credits_required
|
||||
db_session_with_containers.expire_all()
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert pool.quota_used == credits_required
|
||||
|
||||
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
|
||||
|
||||
assert result == remaining
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == pool.quota_limit
|
||||
@ -397,6 +397,68 @@ class TestDatasetPermissionServiceClearPartialMemberList:
|
||||
class TestDatasetServiceCheckDatasetPermission:
|
||||
"""Verify dataset access checks against persisted partial-member permissions."""
|
||||
|
||||
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers):
|
||||
"""Test that users from different tenants cannot access dataset."""
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM
|
||||
)
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.check_dataset_permission(dataset, other_user)
|
||||
|
||||
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers):
|
||||
"""Test that tenant owners can access any dataset regardless of permission level."""
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL, tenant=tenant
|
||||
)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, owner)
|
||||
|
||||
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers):
|
||||
"""Test ONLY_ME permission allows only the dataset creator to access."""
|
||||
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, creator)
|
||||
|
||||
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers):
|
||||
"""Test ONLY_ME permission denies access to non-creators."""
|
||||
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL, tenant=tenant
|
||||
)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.check_dataset_permission(dataset, other)
|
||||
|
||||
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers):
|
||||
"""Test ALL_TEAM permission allows any team member to access the dataset."""
|
||||
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL, tenant=tenant
|
||||
)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM
|
||||
)
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, member)
|
||||
|
||||
def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
|
||||
"""
|
||||
Test that user with explicit permission can access partial_members dataset.
|
||||
@ -443,6 +505,16 @@ class TestDatasetServiceCheckDatasetPermission:
|
||||
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers):
|
||||
"""Test PARTIAL_TEAM permission allows creator to access without explicit permission."""
|
||||
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, creator)
|
||||
|
||||
|
||||
class TestDatasetServiceCheckDatasetOperatorPermission:
|
||||
"""Verify operator permission checks against persisted partial-member permissions."""
|
||||
|
||||
@ -694,3 +694,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
|
||||
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1)
|
||||
patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id)
|
||||
|
||||
def test_batch_update_invalid_action_raises_value_error(
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Test that an invalid action raises ValueError."""
|
||||
factory = DocumentBatchUpdateIntegrationDataFactory
|
||||
dataset = factory.create_dataset(db_session_with_containers)
|
||||
doc = factory.create_document(db_session_with_containers, dataset)
|
||||
user = UserDouble(id=str(uuid4()))
|
||||
|
||||
patched_dependencies["redis_client"].get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid action"):
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user
|
||||
)
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
"""Testcontainers integration tests for DatasetService.create_empty_rag_pipeline_dataset."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
|
||||
|
||||
class TestDatasetServiceCreateRagPipelineDataset:
|
||||
def _create_tenant_and_account(self, db_session_with_containers) -> tuple[Tenant, Account]:
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"ds_create_{uuid4()}@example.com",
|
||||
password="hashed",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role="owner",
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
return tenant, account
|
||||
|
||||
def _build_entity(self, name: str = "Test Dataset") -> RagPipelineDatasetCreateEntity:
|
||||
icon_info = IconInfo(icon="\U0001f4d9", icon_background="#FFF4ED", icon_type="emoji")
|
||||
return RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers):
|
||||
tenant, _ = self._create_tenant_and_account(db_session_with_containers)
|
||||
|
||||
mock_user = Mock(id=None)
|
||||
with patch("services.dataset_service.current_user", mock_user):
|
||||
with pytest.raises(ValueError, match="Current user or current user id not found"):
|
||||
DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant.id,
|
||||
rag_pipeline_dataset_create_entity=self._build_entity(),
|
||||
)
|
||||
@ -142,3 +142,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c
|
||||
|
||||
rows = db_session_with_containers.scalars(filtered).all()
|
||||
assert {row.id for row in rows} == {doc1.id, doc2.id}
|
||||
|
||||
|
||||
def test_normalize_display_status_alias_mapping():
|
||||
"""Test that normalize_display_status maps aliases correctly."""
|
||||
assert DocumentService.normalize_display_status("ACTIVE") == "available"
|
||||
assert DocumentService.normalize_display_status("enabled") == "available"
|
||||
assert DocumentService.normalize_display_status("archived") == "archived"
|
||||
assert DocumentService.normalize_display_status("unknown") is None
|
||||
|
||||
@ -414,3 +414,144 @@ class TestEndUserServiceGetEndUserById:
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestEndUserServiceCreateBatch:
|
||||
"""Integration tests for EndUserService.create_end_user_batch."""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3):
|
||||
"""Create multiple apps under the same tenant."""
|
||||
first_app = factory.create_app_and_account(db_session_with_containers)
|
||||
tenant_id = first_app.tenant_id
|
||||
apps = [first_app]
|
||||
for _ in range(count - 1):
|
||||
app = App(
|
||||
tenant_id=tenant_id,
|
||||
name=f"App {uuid4()}",
|
||||
description="",
|
||||
mode="chat",
|
||||
icon_type="emoji",
|
||||
icon="bot",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=False,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=100,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=first_app.created_by,
|
||||
updated_by=first_app.updated_by,
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all()
|
||||
return tenant_id, all_apps
|
||||
|
||||
def test_create_batch_empty_app_ids(self, db_session_with_containers):
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
|
||||
app_ids = [a.id for a in apps]
|
||||
user_id = f"user-{uuid4()}"
|
||||
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
for app_id in app_ids:
|
||||
assert app_id in result
|
||||
assert result[app_id].session_id == user_id
|
||||
assert result[app_id].type == InvokeFrom.SERVICE_API
|
||||
|
||||
def test_create_batch_default_session_id(self, db_session_with_containers, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
|
||||
app_ids = [a.id for a in apps]
|
||||
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=""
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
for end_user in result.values():
|
||||
assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
assert end_user._is_anonymous is True
|
||||
|
||||
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
|
||||
app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id]
|
||||
user_id = f"user-{uuid4()}"
|
||||
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_create_batch_returns_existing_users(self, db_session_with_containers, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
|
||||
app_ids = [a.id for a in apps]
|
||||
user_id = f"user-{uuid4()}"
|
||||
|
||||
# Create batch first time
|
||||
first_result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Create batch second time — should return existing users
|
||||
second_result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
assert len(second_result) == 2
|
||||
for app_id in app_ids:
|
||||
assert first_result[app_id].id == second_result[app_id].id
|
||||
|
||||
def test_create_batch_partial_existing_users(self, db_session_with_containers, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
|
||||
user_id = f"user-{uuid4()}"
|
||||
|
||||
# Create for first 2 apps
|
||||
first_result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_ids=[apps[0].id, apps[1].id],
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Create for all 3 apps — should reuse first 2, create 3rd
|
||||
all_result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_ids=[a.id for a in apps],
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
assert len(all_result) == 3
|
||||
assert all_result[apps[0].id].id == first_result[apps[0].id].id
|
||||
assert all_result[apps[1].id].id == first_result[apps[1].id].id
|
||||
assert all_result[apps[2].id].session_id == user_id
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invoke_type",
|
||||
[InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER],
|
||||
)
|
||||
def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1)
|
||||
user_id = f"user-{uuid4()}"
|
||||
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=invoke_type, tenant_id=tenant_id, app_ids=[apps[0].id], user_id=user_id
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[apps[0].id].type == invoke_type
|
||||
|
||||
@ -0,0 +1,96 @@
|
||||
"""
|
||||
Testcontainers integration tests for FileService helpers.
|
||||
|
||||
Covers:
|
||||
- ZIP tempfile building (sanitization + deduplication + content writes)
|
||||
- tenant-scoped batch lookup behavior (get_upload_files_by_ids)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
|
||||
import services.file_service as file_service_module
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
def _create_upload_file(db_session, *, tenant_id: str, key: str, name: str) -> UploadFile:
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=StorageType.OPENDAL,
|
||||
key=key,
|
||||
name=name,
|
||||
size=100,
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
db_session.add(upload_file)
|
||||
db_session.commit()
|
||||
return upload_file
|
||||
|
||||
|
||||
def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Ensure ZIP entry names are safe and unique while preserving extensions."""
|
||||
upload_files: list[Any] = [
|
||||
SimpleNamespace(name="a/b.txt", key="k1"),
|
||||
SimpleNamespace(name="c/b.txt", key="k2"),
|
||||
SimpleNamespace(name="../b.txt", key="k3"),
|
||||
]
|
||||
|
||||
data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]}
|
||||
|
||||
def _load(key: str, stream: bool = True) -> list[bytes]:
|
||||
assert stream is True
|
||||
return data_by_key[key]
|
||||
|
||||
monkeypatch.setattr(file_service_module.storage, "load", _load)
|
||||
|
||||
with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp:
|
||||
with ZipFile(tmp, mode="r") as zf:
|
||||
assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"]
|
||||
assert zf.read("b.txt") == b"one"
|
||||
assert zf.read("b (1).txt") == b"two"
|
||||
assert zf.read("b (2).txt") == b"three"
|
||||
|
||||
|
||||
def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers) -> None:
|
||||
"""Ensure empty input returns an empty mapping without hitting the database."""
|
||||
assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {}
|
||||
|
||||
|
||||
def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers) -> None:
|
||||
"""Ensure batch lookup returns a dict keyed by stringified UploadFile ids."""
|
||||
tenant_id = str(uuid4())
|
||||
file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt")
|
||||
file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt")
|
||||
|
||||
result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id])
|
||||
|
||||
assert set(result.keys()) == {file1.id, file2.id}
|
||||
assert result[file1.id].id == file1.id
|
||||
assert result[file2.id].id == file2.id
|
||||
|
||||
|
||||
def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers) -> None:
|
||||
"""Ensure files from other tenants are not returned."""
|
||||
tenant_a = str(uuid4())
|
||||
tenant_b = str(uuid4())
|
||||
file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt")
|
||||
_create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt")
|
||||
|
||||
result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id])
|
||||
|
||||
assert set(result.keys()) == {file_a.id}
|
||||
@ -8,6 +8,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.file.enums import FileType
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
# MessageFile
|
||||
file = MessageFile(
|
||||
message_id=message.id,
|
||||
type="image",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method="local_file",
|
||||
url="http://example.com/test.jpg",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
|
||||
@ -0,0 +1,174 @@
|
||||
"""Testcontainers integration tests for OAuthServerService."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthServerServiceGetProviderApp:
|
||||
"""DB-backed tests for get_oauth_provider_app."""
|
||||
|
||||
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
|
||||
app = OAuthProviderApp(
|
||||
app_icon="icon.png",
|
||||
client_id=client_id,
|
||||
client_secret=str(uuid4()),
|
||||
app_label={"en-US": "Test OAuth App"},
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
scope="read",
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
|
||||
client_id = f"client-{uuid4()}"
|
||||
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
|
||||
|
||||
result = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.client_id == client_id
|
||||
assert result.id == created.id
|
||||
|
||||
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
|
||||
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestOAuthServerServiceTokenOperations:
|
||||
"""Redis-backed tests for token sign/validate operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
with patch("services.oauth_server.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
assert code == str(deterministic_uuid)
|
||||
mock_redis.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
|
||||
mock_redis.get.return_value = b"user-1"
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
mock_redis.delete.assert_called_once_with(code_key)
|
||||
mock_redis.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
mock_redis.get.return_value = b"user-1"
|
||||
|
||||
access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh == "refresh-1"
|
||||
|
||||
def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_validate_access_token_loads_user_when_exists(self, mock_redis):
|
||||
mock_redis.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
|
||||
with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
assert result is expected_user
|
||||
mock_load.assert_called_once_with("user-88")
|
||||
@ -396,11 +396,6 @@ class TestSavedMessageService:
|
||||
|
||||
assert "User is required" in str(exc_info.value)
|
||||
|
||||
# Verify no database operations were performed
|
||||
|
||||
saved_messages = db_session_with_containers.query(SavedMessage).all()
|
||||
assert len(saved_messages) == 0
|
||||
|
||||
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when saving message with no user.
|
||||
@ -497,124 +492,140 @@ class TestSavedMessageService:
|
||||
# The message should still exist, only the saved_message should be deleted
|
||||
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None
|
||||
|
||||
def test_pagination_by_last_id_error_no_user(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when no user is provided.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing user
|
||||
- ValueError is raised when user is None
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
def test_save_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""Test saving a message for an EndUser."""
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
end_user = self._create_test_end_user(db_session_with_containers, app)
|
||||
message = self._create_test_message(db_session_with_containers, app, end_user)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
|
||||
mock_external_service_dependencies["message_service"].get_message.return_value = message
|
||||
|
||||
assert "User is required" in str(exc_info.value)
|
||||
SavedMessageService.save(app_model=app, user=end_user, message_id=message.id)
|
||||
|
||||
# Verify no database operations were performed for this specific test
|
||||
# Note: We don't check total count as other tests may have created data
|
||||
# Instead, we verify that the error was properly raised
|
||||
pass
|
||||
|
||||
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when saving message with no user.
|
||||
|
||||
This test verifies:
|
||||
- Method returns early when user is None
|
||||
- No database operations are performed
|
||||
- No exceptions are raised
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Act: Execute the method under test with None user
|
||||
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
|
||||
# Verify no saved message was created
|
||||
|
||||
saved_message = (
|
||||
saved = (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
)
|
||||
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
|
||||
.first()
|
||||
)
|
||||
assert saved is not None
|
||||
assert saved.created_by == end_user.id
|
||||
assert saved.created_by_role == "end_user"
|
||||
|
||||
assert saved_message is None
|
||||
|
||||
def test_delete_success_existing_message(
|
||||
def test_save_duplicate_is_idempotent(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful deletion of an existing saved message.
|
||||
|
||||
This test verifies:
|
||||
- Proper deletion of existing saved message
|
||||
- Correct database state after deletion
|
||||
- No errors during deletion process
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
"""Test that saving an already-saved message does not create a duplicate."""
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Create a saved message first
|
||||
saved_message = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
mock_external_service_dependencies["message_service"].get_message.return_value = message
|
||||
|
||||
db_session_with_containers.add(saved_message)
|
||||
# Save once
|
||||
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
|
||||
# Save again
|
||||
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
count = (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
|
||||
.count()
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
def test_delete_without_user_does_nothing(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that deleting without a user is a no-op."""
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Pre-create a saved message
|
||||
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id)
|
||||
db_session_with_containers.add(saved)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Verify saved message exists
|
||||
SavedMessageService.delete(app_model=app, user=None, message_id=message.id)
|
||||
|
||||
# Should still exist
|
||||
assert (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def test_delete_non_existent_does_nothing(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that deleting a non-existent saved message is a no-op."""
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Should not raise — use a valid UUID that doesn't exist in DB
|
||||
from uuid import uuid4
|
||||
|
||||
SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4()))
|
||||
|
||||
def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""Test deleting a saved message for an EndUser."""
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
end_user = self._create_test_end_user(db_session_with_containers, app)
|
||||
message = self._create_test_message(db_session_with_containers, app, end_user)
|
||||
|
||||
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id)
|
||||
db_session_with_containers.add(saved)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id)
|
||||
|
||||
assert (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
|
||||
.first()
|
||||
is None
|
||||
)
|
||||
|
||||
def test_delete_only_affects_own_saved_messages(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that delete only removes the requesting user's saved message."""
|
||||
app, account1 = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
end_user = self._create_test_end_user(db_session_with_containers, app)
|
||||
message = self._create_test_message(db_session_with_containers, app, account1)
|
||||
|
||||
# Both users save the same message
|
||||
saved_account = SavedMessage(
|
||||
app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id
|
||||
)
|
||||
saved_end_user = SavedMessage(
|
||||
app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id
|
||||
)
|
||||
db_session_with_containers.add_all([saved_account, saved_end_user])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Delete only account1's saved message
|
||||
SavedMessageService.delete(app_model=app, user=account1, message_id=message.id)
|
||||
|
||||
# Account's saved message should be gone
|
||||
assert (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
SavedMessage.created_by == account1.id,
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
is None
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was deleted from database
|
||||
deleted_saved_message = (
|
||||
# End user's saved message should still exist
|
||||
assert (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
SavedMessage.created_by == end_user.id,
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
assert deleted_saved_message is None
|
||||
|
||||
# Verify database state
|
||||
db_session_with_containers.commit()
|
||||
# The message should still exist, only the saved_message should be deleted
|
||||
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None
|
||||
|
||||
@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset
|
||||
from models.enums import DataSourceType
|
||||
from models.enums import DataSourceType, TagType
|
||||
from models.model import App, Tag, TagBinding
|
||||
from services.tag_service import TagService
|
||||
|
||||
@ -547,7 +547,7 @@ class TestTagService:
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "python_tag"
|
||||
assert result[0].type == "app"
|
||||
assert result[0].type == TagType.APP
|
||||
assert result[0].tenant_id == tenant.id
|
||||
|
||||
def test_get_tag_by_tag_name_no_matches(
|
||||
@ -638,7 +638,7 @@ class TestTagService:
|
||||
|
||||
# Verify all tags are returned
|
||||
for tag in result:
|
||||
assert tag.type == "app"
|
||||
assert tag.type == TagType.APP
|
||||
assert tag.tenant_id == tenant.id
|
||||
assert tag.id in [t.id for t in tags]
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import WorkflowAppLogCreatedFrom
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
@ -221,7 +222,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -357,7 +358,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_1.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -399,7 +400,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_2.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -441,7 +442,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_4.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -521,7 +522,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -627,7 +628,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -732,7 +733,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -860,7 +861,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -902,7 +903,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="web-app",
|
||||
created_from=WorkflowAppLogCreatedFrom.WEB_APP,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=end_user.id,
|
||||
)
|
||||
@ -1037,7 +1038,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -1125,7 +1126,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -1279,7 +1280,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -1379,7 +1380,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
@ -1481,7 +1482,7 @@ class TestWorkflowAppService:
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
@ -536,3 +536,151 @@ class TestApiToolManageService:
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
|
||||
def test_delete_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test successful deletion of an API tool provider."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
schema = self._create_test_openapi_schema()
|
||||
provider_name = fake.unique.word()
|
||||
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
provider = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
deleted = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
assert deleted is None
|
||||
|
||||
def test_delete_api_tool_provider_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test deletion raises ValueError when provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_update_api_tool_provider_not_found(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test update raises ValueError when original provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name="new-name",
|
||||
original_provider="nonexistent",
|
||||
icon={},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
def test_update_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test update raises ValueError when auth_type is missing from credentials."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
schema = self._create_test_openapi_schema()
|
||||
provider_name = fake.unique.word()
|
||||
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
original_provider=provider_name,
|
||||
icon={},
|
||||
credentials={},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
def test_list_api_tool_provider_tools_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test listing tools raises ValueError when provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_test_api_tool_preview_invalid_schema_type(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test preview raises ValueError for invalid schema type."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="invalid schema type"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id=tenant.id,
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type="bad-schema-type",
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService:
|
||||
# After the fix, this should always be 0
|
||||
# For now, we document that the record may exist, demonstrating the bug
|
||||
# assert tool_count == 0 # Expected after fix
|
||||
|
||||
def test_delete_workflow_tool_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test successful deletion of a workflow tool."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
tool_name = fake.unique.word()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=self._create_test_workflow_tool_parameters(),
|
||||
)
|
||||
|
||||
tool = (
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name)
|
||||
.first()
|
||||
)
|
||||
assert tool is not None
|
||||
|
||||
result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
deleted = (
|
||||
db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first()
|
||||
)
|
||||
assert deleted is None
|
||||
|
||||
def test_list_tenant_workflow_tools_empty(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test listing workflow tools when none exist returns empty list."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_get_workflow_tool_by_tool_id_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that get_workflow_tool_by_tool_id raises ValueError when tool not found."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Tool not found"):
|
||||
WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4())
|
||||
|
||||
def test_get_workflow_tool_by_app_id_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that get_workflow_tool_by_app_id raises ValueError when tool not found."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Tool not found"):
|
||||
WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4())
|
||||
|
||||
def test_list_single_workflow_tools_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that list_single_workflow_tools raises ValueError when tool not found."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4())
|
||||
|
||||
def test_create_workflow_tool_with_labels(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that labels are forwarded to ToolLabelManager when provided."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
result = WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.unique.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=self._create_test_workflow_tool_parameters(),
|
||||
labels=["label-1", "label-2"],
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
|
||||
|
||||
@ -0,0 +1,158 @@
|
||||
"""Testcontainers integration tests for WorkflowService.delete_workflow."""
|
||||
|
||||
import json
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
|
||||
class TestWorkflowDeletion:
|
||||
def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]:
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"wf_del_{uuid4()}@example.com",
|
||||
password="hashed",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
session.add(account)
|
||||
session.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role="owner",
|
||||
current=True,
|
||||
)
|
||||
session.add(join)
|
||||
session.flush()
|
||||
return tenant, account
|
||||
|
||||
def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App:
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"App {uuid4()}",
|
||||
description="",
|
||||
mode="workflow",
|
||||
icon_type="emoji",
|
||||
icon="bot",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=False,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=100,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
return app
|
||||
|
||||
def _create_workflow(
|
||||
self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0"
|
||||
) -> Workflow:
|
||||
workflow = Workflow(
|
||||
id=str(uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
version=version,
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
_features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
session.add(workflow)
|
||||
session.flush()
|
||||
return workflow
|
||||
|
||||
def _create_tool_provider(
|
||||
self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str
|
||||
) -> WorkflowToolProvider:
|
||||
provider = WorkflowToolProvider(
|
||||
name=f"tool-{uuid4()}",
|
||||
label=f"Tool {uuid4()}",
|
||||
icon="wrench",
|
||||
app_id=app.id,
|
||||
version=version,
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
description="test tool provider",
|
||||
)
|
||||
session.add(provider)
|
||||
session.flush()
|
||||
return provider
|
||||
|
||||
def test_delete_workflow_success(self, db_session_with_containers):
|
||||
tenant, account = self._create_tenant_and_account(db_session_with_containers)
|
||||
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
|
||||
workflow = self._create_workflow(
|
||||
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
workflow_id = workflow.id
|
||||
|
||||
service = WorkflowService(sessionmaker(bind=db.engine))
|
||||
result = service.delete_workflow(
|
||||
session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
db_session_with_containers.expire_all()
|
||||
assert db_session_with_containers.get(Workflow, workflow_id) is None
|
||||
|
||||
def test_delete_draft_workflow_raises_error(self, db_session_with_containers):
|
||||
tenant, account = self._create_tenant_and_account(db_session_with_containers)
|
||||
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
|
||||
workflow = self._create_workflow(
|
||||
db_session_with_containers, tenant=tenant, app=app, account=account, version="draft"
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
service = WorkflowService(sessionmaker(bind=db.engine))
|
||||
with pytest.raises(DraftWorkflowDeletionError):
|
||||
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
|
||||
|
||||
def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers):
|
||||
tenant, account = self._create_tenant_and_account(db_session_with_containers)
|
||||
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
|
||||
workflow = self._create_workflow(
|
||||
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
|
||||
)
|
||||
# Point app to this workflow
|
||||
app.workflow_id = workflow.id
|
||||
db_session_with_containers.commit()
|
||||
|
||||
service = WorkflowService(sessionmaker(bind=db.engine))
|
||||
with pytest.raises(WorkflowInUseError, match="currently in use by app"):
|
||||
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
|
||||
|
||||
def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers):
|
||||
tenant, account = self._create_tenant_and_account(db_session_with_containers)
|
||||
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
|
||||
workflow = self._create_workflow(
|
||||
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
|
||||
)
|
||||
self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0")
|
||||
db_session_with_containers.commit()
|
||||
|
||||
service = WorkflowService(sessionmaker(bind=db.engine))
|
||||
with pytest.raises(WorkflowInUseError, match="published as a tool"):
|
||||
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
|
||||
@ -281,12 +281,10 @@ class TestSiteEndpoints:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
@ -305,12 +303,10 @@ class TestSiteEndpoints:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
|
||||
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = conversation
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = conversation
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
|
||||
|
||||
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged():
|
||||
),
|
||||
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
|
||||
):
|
||||
mock_session.query.return_value.where.return_value.first.return_value = conversation
|
||||
mock_session.scalar.return_value = conversation
|
||||
|
||||
_get_conversation(app_model, "conversation-id")
|
||||
|
||||
|
||||
@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
|
||||
@ -1,324 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g, request
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.message import (
|
||||
ChatMessageListApi,
|
||||
ChatMessagesQuery,
|
||||
FeedbackExportQuery,
|
||||
MessageAnnotationCountApi,
|
||||
MessageApi,
|
||||
MessageFeedbackApi,
|
||||
MessageFeedbackExportApi,
|
||||
MessageFeedbackPayload,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from models import App, AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
flask_app.login_manager = mock_lm
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.CHAT
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
import contextlib
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def setup_test_context(
|
||||
test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None
|
||||
):
|
||||
with (
|
||||
patch("extensions.ext_database.db") as mock_db,
|
||||
patch("controllers.console.app.wraps.db", mock_db),
|
||||
patch("controllers.console.wraps.db", mock_db),
|
||||
patch("controllers.console.app.message.db", mock_db),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
# Set up a generic query mock that usually returns mock_app_model when getting app
|
||||
app_query_mock = MagicMock()
|
||||
app_query_mock.filter.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.where.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
|
||||
data_query_mock = MagicMock()
|
||||
|
||||
def query_side_effect(*args, **kwargs):
|
||||
if args and hasattr(args[0], "__name__") and args[0].__name__ == "App":
|
||||
return app_query_mock
|
||||
return data_query_mock
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
mock_db.data_query = data_query_mock
|
||||
|
||||
# Let the caller override the stat db query logic
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
|
||||
full_path = f"{route_path}?{query_string}" if qs else route_path
|
||||
|
||||
with (
|
||||
patch("libs.login.current_user", proxy_mock),
|
||||
patch("flask_login.current_user", proxy_mock),
|
||||
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
|
||||
):
|
||||
with test_app.test_request_context(full_path, method=method, json=payload):
|
||||
g._login_user = mock_account
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
|
||||
if "suggested-questions" in route_path:
|
||||
# simplistic extraction for message_id
|
||||
parts = route_path.split("chat-messages/")
|
||||
if len(parts) > 1:
|
||||
request.view_args["message_id"] = parts[1].split("/")[0]
|
||||
elif "messages/" in route_path and "chat-messages" not in route_path:
|
||||
parts = route_path.split("messages/")
|
||||
if len(parts) > 1:
|
||||
request.view_args["message_id"] = parts[1].split("/")[0]
|
||||
|
||||
api_instance = endpoint_class()
|
||||
|
||||
# Check if it has a dispatch_request or method
|
||||
if hasattr(api_instance, method.lower()):
|
||||
yield api_instance, mock_db, request.view_args
|
||||
|
||||
|
||||
class TestMessageValidators:
|
||||
def test_chat_messages_query_validators(self):
|
||||
# Test empty_to_none
|
||||
assert ChatMessagesQuery.empty_to_none("") is None
|
||||
assert ChatMessagesQuery.empty_to_none("val") == "val"
|
||||
|
||||
# Test validate_uuid
|
||||
assert ChatMessagesQuery.validate_uuid(None) is None
|
||||
assert (
|
||||
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_message_feedback_validators(self):
|
||||
assert (
|
||||
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_feedback_export_validators(self):
|
||||
assert FeedbackExportQuery.parse_bool(None) is None
|
||||
assert FeedbackExportQuery.parse_bool(True) is True
|
||||
assert FeedbackExportQuery.parse_bool("1") is True
|
||||
assert FeedbackExportQuery.parse_bool("0") is False
|
||||
assert FeedbackExportQuery.parse_bool("off") is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackExportQuery.parse_bool("invalid")
|
||||
|
||||
|
||||
class TestMessageEndpoints:
|
||||
def test_chat_message_list_not_found(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
ChatMessageListApi,
|
||||
"/apps/app_123/chat-messages",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
api.get(**v_args)
|
||||
|
||||
def test_chat_message_list_success(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
ChatMessageListApi,
|
||||
"/apps/app_123/chat-messages",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = "msg_123"
|
||||
mock_msg.feedbacks = []
|
||||
mock_msg.annotation = None
|
||||
mock_msg.annotation_hit_history = None
|
||||
mock_msg.agent_thoughts = []
|
||||
mock_msg.message_files = []
|
||||
mock_msg.extra_contents = []
|
||||
mock_msg.message = {}
|
||||
mock_msg.message_metadata_dict = {}
|
||||
|
||||
# mock returns
|
||||
q_mock = mock_db.data_query
|
||||
q_mock.where.return_value.first.side_effect = [mock_conv]
|
||||
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
|
||||
mock_db.session.scalar.side_effect = [MagicMock(), False]
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["limit"] == 1
|
||||
assert resp["has_more"] is False
|
||||
assert len(resp["data"]) == 1
|
||||
|
||||
def test_message_feedback_not_found(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageFeedbackApi,
|
||||
"/apps/app_123/feedbacks",
|
||||
"POST",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
api.post(**v_args)
|
||||
|
||||
def test_message_feedback_success(self, app, mock_account, mock_app_model):
|
||||
payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"}
|
||||
with setup_test_context(
|
||||
app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload
|
||||
) as (api, mock_db, v_args):
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.admin_feedback = None
|
||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
||||
|
||||
resp = api.post(**v_args)
|
||||
assert resp == {"result": "success"}
|
||||
|
||||
def test_message_annotation_count(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.count.return_value = 5
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"count": 5}
|
||||
|
||||
@patch("controllers.console.app.message.MessageService")
|
||||
def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model):
|
||||
mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"]
|
||||
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageSuggestedQuestionApi,
|
||||
"/apps/app_123/chat-messages/msg_123/suggested-questions",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
) as (api, mock_db, v_args):
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"data": ["q1", "q2"]}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected_exc"),
|
||||
[
|
||||
(MessageNotExistsError, NotFound),
|
||||
(ConversationNotExistsError, NotFound),
|
||||
(ProviderTokenNotInitError, ProviderNotInitializeError),
|
||||
(QuotaExceededError, ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError),
|
||||
(SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError),
|
||||
(Exception, InternalServerError),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.app.message.MessageService")
|
||||
def test_message_suggested_questions_errors(
|
||||
self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model
|
||||
):
|
||||
mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc()
|
||||
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageSuggestedQuestionApi,
|
||||
"/apps/app_123/chat-messages/msg_123/suggested-questions",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
) as (api, mock_db, v_args):
|
||||
with pytest.raises(expected_exc):
|
||||
api.get(**v_args)
|
||||
|
||||
@patch("services.feedback_service.FeedbackService.export_feedbacks")
|
||||
def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model):
|
||||
mock_export.return_value = {"exported": True}
|
||||
|
||||
with setup_test_context(
|
||||
app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"exported": True}
|
||||
|
||||
def test_message_api_get_success(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = "msg_123"
|
||||
mock_msg.feedbacks = []
|
||||
mock_msg.annotation = None
|
||||
mock_msg.annotation_hit_history = None
|
||||
mock_msg.agent_thoughts = []
|
||||
mock_msg.message_files = []
|
||||
mock_msg.extra_contents = []
|
||||
mock_msg.message = {}
|
||||
mock_msg.message_metadata_dict = {}
|
||||
|
||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["id"] == "msg_123"
|
||||
@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = original_config
|
||||
session.query.return_value = query
|
||||
session.get.return_value = original_config
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -1,279 +0,0 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.statistic import (
|
||||
AverageResponseTimeStatistic,
|
||||
AverageSessionInteractionStatistic,
|
||||
DailyConversationStatistic,
|
||||
DailyMessageStatistic,
|
||||
DailyTerminalsStatistic,
|
||||
DailyTokenCostStatistic,
|
||||
TokensPerSecondStatistic,
|
||||
UserSatisfactionRateStatistic,
|
||||
)
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
flask_app.login_manager = mock_lm
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.CHAT
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def setup_test_context(
|
||||
test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None)
|
||||
):
|
||||
with (
|
||||
patch("controllers.console.app.statistic.db") as mock_db_stat,
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch(
|
||||
"controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123")
|
||||
),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute.return_value = mock_rs
|
||||
|
||||
mock_begin = MagicMock()
|
||||
mock_begin.__enter__.return_value = mock_conn
|
||||
mock_db_stat.engine.begin.return_value = mock_begin
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method="GET"):
|
||||
g._login_user = mock_account
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
api_instance = endpoint_class()
|
||||
response = api_instance.get(app_id="app_123")
|
||||
return response
|
||||
|
||||
|
||||
class TestStatisticEndpoints:
|
||||
def test_daily_message_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.message_count = 10
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["message_count"] == 10
|
||||
|
||||
def test_daily_conversation_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.conversation_count = 5
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyConversationStatistic,
|
||||
"/apps/app_123/statistics/daily-conversations",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["conversation_count"] == 5
|
||||
|
||||
def test_daily_terminals_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.terminal_count = 2
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyTerminalsStatistic,
|
||||
"/apps/app_123/statistics/daily-end-users",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["terminal_count"] == 2
|
||||
|
||||
def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.token_count = 100
|
||||
mock_row.total_price = Decimal("0.02")
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyTokenCostStatistic,
|
||||
"/apps/app_123/statistics/token-costs",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["token_count"] == 100
|
||||
assert response.json["data"][0]["total_price"] == "0.02"
|
||||
|
||||
def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.interactions = Decimal("3.523")
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
AverageSessionInteractionStatistic,
|
||||
"/apps/app_123/statistics/average-session-interactions",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["interactions"] == 3.52
|
||||
|
||||
def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.message_count = 100
|
||||
mock_row.feedback_count = 10
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
UserSatisfactionRateStatistic,
|
||||
"/apps/app_123/statistics/user-satisfaction-rate",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["rate"] == 100.0
|
||||
|
||||
def test_average_response_time_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_app_model.mode = AppMode.COMPLETION
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.latency = 1.234
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
AverageResponseTimeStatistic,
|
||||
"/apps/app_123/statistics/average-response-time",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["latency"] == 1234.0
|
||||
|
||||
def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.tokens_per_second = 15.5
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
TokensPerSecondStatistic,
|
||||
"/apps/app_123/statistics/tokens-per-second",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["tps"] == 15.5
|
||||
|
||||
@patch("controllers.console.app.statistic.parse_time_range")
|
||||
def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model):
|
||||
mock_parse.side_effect = ValueError("Invalid time")
|
||||
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=invalid&end=invalid",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[],
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.statistic.parse_time_range")
|
||||
def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model):
|
||||
import datetime
|
||||
|
||||
start = datetime.datetime.now()
|
||||
end = datetime.datetime.now()
|
||||
mock_parse.return_value = (start, end)
|
||||
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=something&end=something",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_parse.assert_called_once()
|
||||
@ -1,321 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
ConversationVariableCollectionApi,
|
||||
EnvironmentVariableCollectionApi,
|
||||
NodeVariableCollectionApi,
|
||||
SystemVariableCollectionApi,
|
||||
VariableApi,
|
||||
VariableResetApi,
|
||||
WorkflowVariableCollectionApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from models import App, AppMode
|
||||
from models.enums import DraftVariableType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
flask_app.login_manager = mock_lm
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None):
|
||||
with (
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch("controllers.console.app.workflow_draft_variable.db"),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method=method, json=payload):
|
||||
g._login_user = mock_account
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
# extract node_id or variable_id from path manually since view_args overrides
|
||||
if "nodes/" in route_path:
|
||||
request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0]
|
||||
if "variables/" in route_path:
|
||||
# simplistic extraction
|
||||
parts = route_path.split("variables/")
|
||||
if len(parts) > 1 and parts[1] and parts[1] != "reset":
|
||||
request.view_args["variable_id"] = parts[1].split("/")[0]
|
||||
|
||||
api_instance = endpoint_class()
|
||||
# we just call dispatch_request to avoid manual argument passing
|
||||
if hasattr(api_instance, method.lower()):
|
||||
func = getattr(api_instance, method.lower())
|
||||
return func(**request.view_args)
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
@staticmethod
|
||||
def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock:
|
||||
class DummyValueType:
|
||||
def exposed_type(self):
|
||||
return DraftVariableType.NODE
|
||||
|
||||
mock_var = MagicMock()
|
||||
mock_var.app_id = "app_123"
|
||||
mock_var.id = "var_123"
|
||||
mock_var.user_id = "user_123"
|
||||
mock_var.name = "test_var"
|
||||
mock_var.description = ""
|
||||
mock_var.get_variable_type.return_value = variable_type
|
||||
mock_var.get_selector.return_value = []
|
||||
mock_var.value_type = DummyValueType()
|
||||
mock_var.edited = False
|
||||
mock_var.visible = True
|
||||
mock_var.file_id = None
|
||||
mock_var.variable_file = None
|
||||
mock_var.is_truncated.return_value = False
|
||||
mock_var.get_value.return_value.model_copy.return_value.value = "test_value"
|
||||
return mock_var
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_workflow_variable_collection_get_success(
|
||||
self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model
|
||||
):
|
||||
mock_wf_srv.return_value.is_workflow_exist.return_value = True
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList(
|
||||
variables=[], total=0
|
||||
)
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables?page=1&limit=20",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": [], "total": 0}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.is_workflow_exist.return_value = False
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.SandboxService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_workflow_variable_collection_delete(
|
||||
self, mock_draft_srv, mock_sandbox_srv, app, mock_account, mock_app_model,
|
||||
):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables",
|
||||
"DELETE",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/node_123/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/sys/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/node_123/variables",
|
||||
"DELETE",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
|
||||
)
|
||||
assert resp["id"] == "var_123"
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = None
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
VariableApi,
|
||||
"/apps/app_123/workflows/draft/variables/var_123",
|
||||
"PATCH",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
payload={"name": "new_name"},
|
||||
)
|
||||
assert resp["id"] == "var_123"
|
||||
mock_draft_srv.return_value.update_variable.assert_called_once()
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
mock_draft_srv.return_value.delete_variable.assert_called_once()
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
mock_draft_srv.return_value.reset_variable.return_value = None # means no content
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
VariableResetApi,
|
||||
"/apps/app_123/workflows/draft/variables/var_123/reset",
|
||||
"PUT",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
ConversationVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/conversation-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
SystemVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/system-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf = MagicMock()
|
||||
mock_wf.environment_variables = []
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
EnvironmentVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/environment-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
@ -11,10 +11,8 @@ from models.model import AppMode
|
||||
|
||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
||||
def handler(app_model):
|
||||
|
||||
@ -1,223 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.auth.data_source_bearer_auth import (
|
||||
ApiKeyAuthDataSource,
|
||||
ApiKeyAuthDataSourceBinding,
|
||||
ApiKeyAuthDataSourceBindingDelete,
|
||||
)
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSource:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
|
||||
def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_binding = MagicMock()
|
||||
mock_binding.id = "bind_123"
|
||||
mock_binding.category = "api_key"
|
||||
mock_binding.provider = "custom_provider"
|
||||
mock_binding.disabled = False
|
||||
mock_binding.created_at.timestamp.return_value = 1620000000
|
||||
mock_binding.updated_at.timestamp.return_value = 1620000001
|
||||
|
||||
mock_get_list.return_value = [mock_binding]
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
assert "sources" in response
|
||||
assert len(response["sources"]) == 1
|
||||
assert response["sources"][0]["provider"] == "custom_provider"
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
|
||||
def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_get_list.return_value = None
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
assert "sources" in response
|
||||
assert len(response["sources"]) == 0
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSourceBinding:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
|
||||
def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_validate.assert_called_once()
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
|
||||
def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_create.side_effect = ValueError("Invalid structure")
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSourceBindingDelete:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth")
|
||||
def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBindingDelete()
|
||||
response = api_instance.delete("binding_123")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 204
|
||||
mock_delete.assert_called_once_with("tenant_123", "binding_123")
|
||||
@ -1,205 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.auth.data_source_oauth import (
|
||||
OAuthDataSource,
|
||||
OAuthDataSourceBinding,
|
||||
OAuthDataSourceCallback,
|
||||
OAuthDataSourceSync,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthDataSource:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("flask_login.current_user")
|
||||
@patch("libs.login.current_user")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
|
||||
def test_get_oauth_url_successful(
|
||||
self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
|
||||
):
|
||||
mock_oauth_provider = MagicMock()
|
||||
mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
|
||||
mock_get_providers.return_value = {"notion": mock_oauth_provider}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
mock_libs_user.return_value = mock_account
|
||||
mock_flask_user.return_value = mock_account
|
||||
|
||||
# also patch current_account_with_tenant
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["data"] == "http://oauth.provider/auth"
|
||||
assert response[1] == 200
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("flask_login.current_user")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("unknown_provider")
|
||||
|
||||
assert response[0]["error"] == "Invalid provider"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceCallback:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_successful(self, mock_get_providers, app):
|
||||
provider_mock = MagicMock()
|
||||
mock_get_providers.return_value = {"notion": provider_mock}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "code=mock_code" in response.location
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_missing_code(self, mock_get_providers, app):
|
||||
provider_mock = MagicMock()
|
||||
mock_get_providers.return_value = {"notion": provider_mock}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "error=Access denied" in response.location
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_invalid_provider(self, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("invalid")
|
||||
|
||||
assert response[0]["error"] == "Invalid provider"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceBinding:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_get_binding_successful(self, mock_get_providers, app):
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_access_token.return_value = None
|
||||
mock_get_providers.return_value = {"notion": mock_provider}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"):
|
||||
api_instance = OAuthDataSourceBinding()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_provider.get_access_token.assert_called_once_with("auth_code_123")
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_get_binding_missing_code(self, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"):
|
||||
api_instance = OAuthDataSourceBinding()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["error"] == "Invalid code"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceSync:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app):
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.sync_data_source.return_value = None
|
||||
mock_get_providers.return_value = {"notion": mock_provider}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
|
||||
g._login_user = mock_account
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSourceSync()
|
||||
# The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
|
||||
response = api_instance.get("notion", "binding_123")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_provider.sync_data_source.assert_called_once_with("binding_123")
|
||||
@ -1,430 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.auth.oauth_server import (
|
||||
OAuthServerAppApi,
|
||||
OAuthServerUserAccountApi,
|
||||
OAuthServerUserAuthorizeApi,
|
||||
OAuthServerUserTokenApi,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthServerAppApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
oauth_app.redirect_uris = ["http://localhost/callback"]
|
||||
oauth_app.app_icon = "icon_url"
|
||||
oauth_app.app_label = "Test App"
|
||||
oauth_app.scope = "read,write"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["app_icon"] == "icon_url"
|
||||
assert response["app_label"] == "Test App"
|
||||
assert response["scope"] == "read,write"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_client_id(self, mock_get_app, mock_db, app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
with pytest.raises(NotFound, match="client_id is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestOAuthServerUserAuthorizeApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
oauth_app = MagicMock()
|
||||
oauth_app.client_id = "test_client_id"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.current_account_with_tenant")
|
||||
@patch("controllers.console.wraps.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
def test_successful_authorize(
|
||||
self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.id = "user_123"
|
||||
from models.account import AccountStatus
|
||||
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
|
||||
mock_current.return_value = (mock_account, MagicMock())
|
||||
mock_wrap_current.return_value = (mock_account, MagicMock())
|
||||
|
||||
mock_sign.return_value = "auth_code_123"
|
||||
|
||||
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
|
||||
g._login_user = mock_account
|
||||
with patch("libs.login.current_user", mock_account):
|
||||
api_instance = OAuthServerUserAuthorizeApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["code"] == "auth_code_123"
|
||||
mock_sign.assert_called_once_with("test_client_id", "user_123")
|
||||
|
||||
|
||||
class TestOAuthServerUserTokenApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
oauth_app.client_secret = "test_secret"
|
||||
oauth_app.redirect_uris = ["http://localhost/callback"]
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
|
||||
def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_sign.return_value = ("access_123", "refresh_123")
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["access_token"] == "access_123"
|
||||
assert response["refresh_token"] == "refresh_123"
|
||||
assert response["token_type"] == "Bearer"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="code is required"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "invalid_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="client_secret is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://invalid/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
|
||||
def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_sign.return_value = ("new_access", "new_refresh")
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["access_token"] == "new_access"
|
||||
assert response["refresh_token"] == "new_refresh"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="refresh_token is required"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "invalid_grant",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="invalid grant_type"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestOAuthServerUserAccountApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
mock_lm = MagicMock()
|
||||
mock_lm._load_user = lambda: setattr(__import__("flask").g, "_login_user", MagicMock())
|
||||
app.login_manager = mock_lm
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
|
||||
def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.name = "Test User"
|
||||
mock_account.email = "test@example.com"
|
||||
mock_account.avatar = "avatar_url"
|
||||
mock_account.interface_language = "en-US"
|
||||
mock_account.timezone = "UTC"
|
||||
mock_validate.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer valid_access_token"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["name"] == "Test User"
|
||||
assert response["email"] == "test@example.com"
|
||||
assert response["avatar"] == "avatar_url"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Authorization header is required"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "InvalidFormat"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Invalid Authorization header format"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Basic something"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "token_type is invalid"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer "},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Invalid Authorization header format"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
|
||||
def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_validate.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "access_token or client_id is invalid"
|
||||
@ -11,6 +11,7 @@ from controllers.console.tag.tags import (
|
||||
TagListApi,
|
||||
TagUpdateDeleteApi,
|
||||
)
|
||||
from models.enums import TagType
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@ -52,7 +53,7 @@ def tag():
|
||||
tag = MagicMock()
|
||||
tag.id = "tag-1"
|
||||
tag.name = "test-tag"
|
||||
tag.type = "knowledge"
|
||||
tag.type = TagType.KNOWLEDGE
|
||||
return tag
|
||||
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from controllers.console.apikey import (
|
||||
BaseApiKeyResource,
|
||||
_get_resource,
|
||||
)
|
||||
from models.enums import ApiTokenType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -45,14 +46,14 @@ def bypass_permissions():
|
||||
|
||||
|
||||
class DummyApiKeyListResource(BaseApiKeyListResource):
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
||||
|
||||
class DummyApiKeyResource(BaseApiKeyResource):
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import (
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.enums import TagType
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.tag_service import TagService
|
||||
|
||||
@ -277,7 +278,7 @@ class TestDatasetTagsApi:
|
||||
mock_tag = Mock()
|
||||
mock_tag.id = "tag_1"
|
||||
mock_tag.name = "Test Tag"
|
||||
mock_tag.type = "knowledge"
|
||||
mock_tag.type = TagType.KNOWLEDGE
|
||||
mock_tag.binding_count = "0" # Required for Pydantic validation - must be string
|
||||
mock_tag_service.get_tags.return_value = [mock_tag]
|
||||
|
||||
@ -316,7 +317,7 @@ class TestDatasetTagsApi:
|
||||
mock_tag = Mock()
|
||||
mock_tag.id = "new_tag_1"
|
||||
mock_tag.name = "New Tag"
|
||||
mock_tag.type = "knowledge"
|
||||
mock_tag.type = TagType.KNOWLEDGE
|
||||
mock_tag_service.save_tags.return_value = mock_tag
|
||||
mock_service_api_ns.payload = {"name": "New Tag"}
|
||||
|
||||
@ -378,7 +379,7 @@ class TestDatasetTagsApi:
|
||||
mock_tag = Mock()
|
||||
mock_tag.id = "tag_1"
|
||||
mock_tag.name = "Updated Tag"
|
||||
mock_tag.type = "knowledge"
|
||||
mock_tag.type = TagType.KNOWLEDGE
|
||||
mock_tag.binding_count = "5"
|
||||
mock_tag_service.update_tags.return_value = mock_tag
|
||||
mock_tag_service.get_tag_binding_count.return_value = 5
|
||||
@ -866,7 +867,7 @@ class TestTagService:
|
||||
mock_tag = Mock()
|
||||
mock_tag.id = str(uuid.uuid4())
|
||||
mock_tag.name = "New Tag"
|
||||
mock_tag.type = "knowledge"
|
||||
mock_tag.type = TagType.KNOWLEDGE
|
||||
mock_save.return_value = mock_tag
|
||||
|
||||
result = TagService.save_tags({"name": "New Tag", "type": "knowledge"})
|
||||
|
||||
@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.task_entities import MessageEndStreamResponse
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles:
|
||||
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
|
||||
message_file.upload_file_id = str(uuid.uuid4())
|
||||
message_file.url = None
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles:
|
||||
message_file.transfer_method = FileTransferMethod.REMOTE_URL
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "https://example.com/image.jpg"
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles:
|
||||
message_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "tool_file_123.png"
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -164,6 +164,13 @@ class TestFirecrawlApp:
|
||||
with pytest.raises(Exception, match="No page found"):
|
||||
app.check_crawl_status("job-1")
|
||||
|
||||
def test_check_crawl_status_completed_with_null_total_raises(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": None, "data": []}))
|
||||
|
||||
with pytest.raises(Exception, match="No page found"):
|
||||
app.check_crawl_status("job-1")
|
||||
|
||||
def test_check_crawl_status_non_completed(self, mocker: MockerFixture):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
payload = {"status": "processing", "total": 5, "completed": 1, "data": []}
|
||||
@ -203,6 +210,77 @@ class TestFirecrawlApp:
|
||||
with pytest.raises(Exception, match="Error saving crawl data"):
|
||||
app.check_crawl_status("job-err")
|
||||
|
||||
def test_check_crawl_status_follows_pagination(self, mocker: MockerFixture):
|
||||
"""When status is completed and next is present, follow pagination to collect all pages."""
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
page1 = {
|
||||
"status": "completed",
|
||||
"total": 3,
|
||||
"completed": 3,
|
||||
"next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=1",
|
||||
"data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}],
|
||||
}
|
||||
page2 = {
|
||||
"status": "completed",
|
||||
"total": 3,
|
||||
"completed": 3,
|
||||
"next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=2",
|
||||
"data": [{"metadata": {"title": "p2", "description": "", "sourceURL": "https://p2"}, "markdown": "m2"}],
|
||||
}
|
||||
page3 = {
|
||||
"status": "completed",
|
||||
"total": 3,
|
||||
"completed": 3,
|
||||
"data": [{"metadata": {"title": "p3", "description": "", "sourceURL": "https://p3"}, "markdown": "m3"}],
|
||||
}
|
||||
mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(200, page2), _response(200, page3)])
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.exists.return_value = False
|
||||
mocker.patch.object(firecrawl_module, "storage", mock_storage)
|
||||
|
||||
result = app.check_crawl_status("job-42")
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert result["total"] == 3
|
||||
assert len(result["data"]) == 3
|
||||
assert [d["title"] for d in result["data"]] == ["p1", "p2", "p3"]
|
||||
|
||||
def test_check_crawl_status_pagination_error_raises(self, mocker: MockerFixture):
|
||||
"""An error while fetching a paginated page raises an exception; no partial data is returned."""
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
page1 = {
|
||||
"status": "completed",
|
||||
"total": 2,
|
||||
"completed": 2,
|
||||
"next": "https://custom.firecrawl.dev/v2/crawl/job-99?skip=1",
|
||||
"data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}],
|
||||
}
|
||||
mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(500, {"error": "server error"})])
|
||||
|
||||
with pytest.raises(Exception, match="fetch next crawl page"):
|
||||
app.check_crawl_status("job-99")
|
||||
|
||||
def test_check_crawl_status_pagination_capped_at_total(self, mocker: MockerFixture):
|
||||
"""Pagination stops once pages_processed reaches total, even if next is present."""
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
# total=1: only the first page should be processed; next must not be followed
|
||||
page1 = {
|
||||
"status": "completed",
|
||||
"total": 1,
|
||||
"completed": 1,
|
||||
"next": "https://custom.firecrawl.dev/v2/crawl/job-cap?skip=1",
|
||||
"data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}],
|
||||
}
|
||||
mock_get = mocker.patch("httpx.get", return_value=_response(200, page1))
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.exists.return_value = False
|
||||
mocker.patch.object(firecrawl_module, "storage", mock_storage)
|
||||
|
||||
result = app.check_crawl_status("job-cap")
|
||||
|
||||
assert len(result["data"]) == 1
|
||||
mock_get.assert_called_once() # initial fetch only; next URL is not followed due to cap
|
||||
|
||||
def test_extract_common_fields_and_status_formatter(self):
|
||||
app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev")
|
||||
|
||||
|
||||
@ -95,13 +95,11 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
],
|
||||
"primary@example.com",
|
||||
),
|
||||
# User with no emails - fallback to noreply
|
||||
({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
|
||||
# User with only secondary email - fallback to noreply
|
||||
# User with private email (null email and name from API)
|
||||
(
|
||||
{"id": 12345, "login": "testuser", "name": "Test User"},
|
||||
[{"email": "secondary@example.com", "primary": False}],
|
||||
"12345+testuser@users.noreply.github.com",
|
||||
{"id": 12345, "login": "testuser", "name": None, "email": None},
|
||||
[{"email": "primary@example.com", "primary": True}],
|
||||
"primary@example.com",
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -118,9 +116,54 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
user_info = oauth.get_user_info("test_token")
|
||||
|
||||
assert user_info.id == str(user_data["id"])
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.name == (user_data["name"] or "")
|
||||
assert user_info.email == expected_email
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_data", "email_data"),
|
||||
[
|
||||
# User with no emails
|
||||
({"id": 12345, "login": "testuser", "name": "Test User"}, []),
|
||||
# User with only secondary email
|
||||
(
|
||||
{"id": 12345, "login": "testuser", "name": "Test User"},
|
||||
[{"email": "secondary@example.com", "primary": False}],
|
||||
),
|
||||
# User with private email and no primary in emails endpoint
|
||||
(
|
||||
{"id": 12345, "login": "testuser", "name": None, "email": None},
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("httpx.get", autospec=True)
|
||||
def test_should_raise_error_when_no_primary_email(self, mock_get, oauth, user_data, email_data):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
|
||||
email_response = MagicMock()
|
||||
email_response.json.return_value = email_data
|
||||
|
||||
mock_get.side_effect = [user_response, email_response]
|
||||
|
||||
with pytest.raises(ValueError, match="Keep my email addresses private"):
|
||||
oauth.get_user_info("test_token")
|
||||
|
||||
@patch("httpx.get", autospec=True)
|
||||
def test_should_raise_error_when_email_endpoint_fails(self, mock_get, oauth):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = {"id": 12345, "login": "testuser", "name": "Test User"}
|
||||
|
||||
email_response = MagicMock()
|
||||
email_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"Forbidden", request=MagicMock(), response=MagicMock()
|
||||
)
|
||||
|
||||
mock_get.side_effect = [user_response, email_response]
|
||||
|
||||
with pytest.raises(ValueError, match="Keep my email addresses private"):
|
||||
oauth.get_user_info("test_token")
|
||||
|
||||
@patch("httpx.get", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
@ -12,7 +12,7 @@ This test suite covers:
|
||||
import json
|
||||
from uuid import uuid4
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
|
||||
from models.tools import (
|
||||
ApiToolProvider,
|
||||
BuiltinToolProvider,
|
||||
@ -631,7 +631,7 @@ class TestToolLabelBinding:
|
||||
"""Test creating a tool label binding."""
|
||||
# Arrange
|
||||
tool_id = "google.search"
|
||||
tool_type = "builtin"
|
||||
tool_type = ToolProviderType.BUILT_IN
|
||||
label_name = "search"
|
||||
|
||||
# Act
|
||||
@ -655,7 +655,7 @@ class TestToolLabelBinding:
|
||||
# Act
|
||||
label_binding = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
label_name=label_name,
|
||||
)
|
||||
|
||||
@ -667,7 +667,7 @@ class TestToolLabelBinding:
|
||||
"""Test multiple labels can be bound to the same tool."""
|
||||
# Arrange
|
||||
tool_id = "google.search"
|
||||
tool_type = "builtin"
|
||||
tool_type = ToolProviderType.BUILT_IN
|
||||
|
||||
# Act
|
||||
binding1 = ToolLabelBinding(
|
||||
@ -688,7 +688,7 @@ class TestToolLabelBinding:
|
||||
def test_tool_label_binding_different_tool_types(self):
|
||||
"""Test label bindings for different tool types."""
|
||||
# Arrange
|
||||
tool_types = ["builtin", "api", "workflow"]
|
||||
tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW]
|
||||
|
||||
# Act & Assert
|
||||
for tool_type in tool_types:
|
||||
@ -951,12 +951,12 @@ class TestToolProviderRelationships:
|
||||
# Act
|
||||
binding1 = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
label_name="search",
|
||||
)
|
||||
binding2 = ToolLabelBinding(
|
||||
tool_id=tool_id,
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
label_name="web",
|
||||
)
|
||||
|
||||
|
||||
@ -1,180 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain
|
||||
from core.entities.execution_extra_content import HumanInputFormSubmissionData
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
FormDefinition,
|
||||
UserAction,
|
||||
)
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormStatus
|
||||
from models.execution_extra_content import HumanInputContent as HumanInputContentModel
|
||||
from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, values: Sequence[HumanInputContentModel]):
|
||||
self._values = list(values)
|
||||
|
||||
def all(self) -> list[HumanInputContentModel]:
|
||||
return list(self._values)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, values: Sequence[Sequence[object]]):
|
||||
self._values = list(values)
|
||||
|
||||
def scalars(self, _stmt):
|
||||
if not self._values:
|
||||
return _FakeScalarResult([])
|
||||
return _FakeScalarResult(self._values.pop(0))
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeSessionMaker:
|
||||
session: _FakeSession
|
||||
|
||||
def __call__(self) -> _FakeSession:
|
||||
return self.session
|
||||
|
||||
|
||||
def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm:
|
||||
expiration_time = datetime.now(UTC) + timedelta(days=1)
|
||||
definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id=action_id, title=action_title)],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
node_title="Approval",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
id=f"form-{action_id}",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_run_id="workflow-run",
|
||||
node_id="node-id",
|
||||
form_definition=definition.model_dump_json(),
|
||||
rendered_content=rendered_content,
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
form.selected_action_id = action_id
|
||||
return form
|
||||
|
||||
|
||||
def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel:
|
||||
form = _build_form(
|
||||
action_id=action_id,
|
||||
action_title=action_title,
|
||||
rendered_content=f"Rendered {action_title}",
|
||||
)
|
||||
content = HumanInputContentModel(
|
||||
id=f"content-{message_id}",
|
||||
form_id=form.id,
|
||||
message_id=message_id,
|
||||
workflow_run_id=form.workflow_run_id,
|
||||
)
|
||||
content.form = form
|
||||
return content
|
||||
|
||||
|
||||
def test_get_by_message_ids_groups_contents_by_message() -> None:
|
||||
message_ids = ["msg-1", "msg-2"]
|
||||
contents = [_build_content("msg-1", "approve", "Approve")]
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []]))
|
||||
)
|
||||
|
||||
result = repository.get_by_message_ids(message_ids)
|
||||
|
||||
assert len(result) == 2
|
||||
assert [content.model_dump(mode="json", exclude_none=True) for content in result[0]] == [
|
||||
HumanInputContentDomain(
|
||||
workflow_run_id="workflow-run",
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id="node-id",
|
||||
node_title="Approval",
|
||||
rendered_content="Rendered Approve",
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
),
|
||||
).model_dump(mode="json", exclude_none=True)
|
||||
]
|
||||
assert result[1] == []
|
||||
|
||||
|
||||
def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None:
|
||||
expiration_time = datetime.now(UTC) + timedelta(days=1)
|
||||
definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "John"},
|
||||
node_title="Approval",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
id="form-1",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_run_id="workflow-run",
|
||||
node_id="node-id",
|
||||
form_definition=definition.model_dump_json(),
|
||||
rendered_content="Rendered block",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
content = HumanInputContentModel(
|
||||
id="content-msg-1",
|
||||
form_id=form.id,
|
||||
message_id="msg-1",
|
||||
workflow_run_id=form.workflow_run_id,
|
||||
)
|
||||
content.form = form
|
||||
|
||||
recipient = HumanInputFormRecipient(
|
||||
form_id=form.id,
|
||||
delivery_id="delivery-1",
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(),
|
||||
access_token="token-1",
|
||||
)
|
||||
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]]))
|
||||
)
|
||||
|
||||
result = repository.get_by_message_ids(["msg-1"])
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 1
|
||||
domain_content = result[0][0]
|
||||
assert domain_content.submitted is False
|
||||
assert domain_content.workflow_run_id == "workflow-run"
|
||||
assert domain_content.form_definition is not None
|
||||
assert domain_content.form_definition.expiration_time == int(form.expiration_time.timestamp())
|
||||
assert domain_content.form_definition is not None
|
||||
form_definition = domain_content.form_definition
|
||||
assert form_definition.form_id == "form-1"
|
||||
assert form_definition.node_id == "node-id"
|
||||
assert form_definition.node_title == "Approval"
|
||||
assert form_definition.form_content == "Rendered block"
|
||||
assert form_definition.display_in_ui is True
|
||||
assert form_definition.form_token == "token-1"
|
||||
assert form_definition.resolved_default_values == {"name": "John"}
|
||||
assert form_definition.expiration_time == int(form.expiration_time.timestamp())
|
||||
@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
from models.account import Account
|
||||
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
|
||||
from models.enums import SegmentType
|
||||
from services.dataset_service import SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
@ -77,7 +78,7 @@ class SegmentTestDataFactory:
|
||||
chunk.word_count = word_count
|
||||
chunk.index_node_id = f"node-{chunk_id}"
|
||||
chunk.index_node_hash = "hash-123"
|
||||
chunk.type = "automatic"
|
||||
chunk.type = SegmentType.AUTOMATIC
|
||||
chunk.created_by = "user-123"
|
||||
chunk.updated_by = None
|
||||
chunk.updated_at = None
|
||||
|
||||
@ -1,421 +0,0 @@
|
||||
"""
|
||||
Comprehensive unit tests for services/api_based_extension_service.py
|
||||
|
||||
Covers:
|
||||
- APIBasedExtensionService.get_all_by_tenant_id
|
||||
- APIBasedExtensionService.save
|
||||
- APIBasedExtensionService.delete
|
||||
- APIBasedExtensionService.get_with_tenant_id
|
||||
- APIBasedExtensionService._validation (new record & existing record branches)
|
||||
- APIBasedExtensionService._ping_connection (pong success, wrong response, exception)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_extension(
|
||||
*,
|
||||
id_: str | None = None,
|
||||
tenant_id: str = "tenant-001",
|
||||
name: str = "my-ext",
|
||||
api_endpoint: str = "https://example.com/hook",
|
||||
api_key: str = "secret-key-123",
|
||||
) -> MagicMock:
|
||||
"""Return a lightweight mock that mimics APIBasedExtension."""
|
||||
ext = MagicMock()
|
||||
ext.id = id_
|
||||
ext.tenant_id = tenant_id
|
||||
ext.name = name
|
||||
ext.api_endpoint = api_endpoint
|
||||
ext.api_key = api_key
|
||||
return ext
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: get_all_by_tenant_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetAllByTenantId:
|
||||
"""Tests for APIBasedExtensionService.get_all_by_tenant_id."""
|
||||
|
||||
@patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt):
|
||||
"""Each api_key is decrypted and the list is returned."""
|
||||
ext1 = _make_extension(id_="id-1", api_key="enc-key-1")
|
||||
ext2 = _make_extension(id_="id-2", api_key="enc-key-2")
|
||||
|
||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [
|
||||
ext1,
|
||||
ext2,
|
||||
]
|
||||
|
||||
result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001")
|
||||
|
||||
assert result == [ext1, ext2]
|
||||
assert ext1.api_key == "decrypted-key"
|
||||
assert ext2.api_key == "decrypted-key"
|
||||
assert mock_decrypt.call_count == 2
|
||||
|
||||
@patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt):
|
||||
"""Returns an empty list gracefully when no records exist."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []
|
||||
|
||||
result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001")
|
||||
|
||||
assert result == []
|
||||
mock_decrypt.assert_not_called()
|
||||
|
||||
@patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt):
|
||||
"""Verifies the DB is queried with the supplied tenant_id."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []
|
||||
|
||||
APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz")
|
||||
|
||||
mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: save
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSave:
|
||||
"""Tests for APIBasedExtensionService.save."""
|
||||
|
||||
@patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
@patch.object(APIBasedExtensionService, "_validation")
|
||||
def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt):
|
||||
"""Happy path: validation passes, key is encrypted, record is added and committed."""
|
||||
ext = _make_extension(id_=None, api_key="plain-key-123")
|
||||
|
||||
result = APIBasedExtensionService.save(ext)
|
||||
|
||||
mock_validation.assert_called_once_with(ext)
|
||||
mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123")
|
||||
assert ext.api_key == "encrypted-key"
|
||||
mock_db.session.add.assert_called_once_with(ext)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
assert result is ext
|
||||
|
||||
@patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
@patch.object(APIBasedExtensionService, "_validation", side_effect=ValueError("name must not be empty"))
|
||||
def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt):
|
||||
"""If _validation raises, save should propagate the error without touching the DB."""
|
||||
ext = _make_extension(name="")
|
||||
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService.save(ext)
|
||||
|
||||
mock_db.session.add.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDelete:
|
||||
"""Tests for APIBasedExtensionService.delete."""
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_delete_removes_record_and_commits(self, mock_db):
|
||||
"""delete() must call session.delete with the extension and then commit."""
|
||||
ext = _make_extension(id_="delete-me")
|
||||
|
||||
APIBasedExtensionService.delete(ext)
|
||||
|
||||
mock_db.session.delete.assert_called_once_with(ext)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: get_with_tenant_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetWithTenantId:
|
||||
"""Tests for APIBasedExtensionService.get_with_tenant_id."""
|
||||
|
||||
@patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt):
|
||||
"""Found extension has its api_key decrypted before being returned."""
|
||||
ext = _make_extension(id_="ext-123", api_key="enc-key")
|
||||
|
||||
(mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext
|
||||
|
||||
result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123")
|
||||
|
||||
assert result is ext
|
||||
assert ext.api_key == "decrypted-key"
|
||||
mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key")
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_value_error_when_not_found(self, mock_db):
|
||||
"""Raises ValueError when no matching extension exists."""
|
||||
(mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None
|
||||
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent")
|
||||
|
||||
@patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt):
|
||||
"""Verifies both tenant_id and extension id are used in the query."""
|
||||
ext = _make_extension(id_="ext-abc")
|
||||
chain = mock_db.session.query.return_value
|
||||
chain.filter_by.return_value.filter_by.return_value.first.return_value = ext
|
||||
|
||||
APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc")
|
||||
|
||||
# First filter_by call uses tenant_id
|
||||
chain.filter_by.assert_called_once_with(tenant_id="tenant-002")
|
||||
# Second filter_by call uses id
|
||||
chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _validation (new record — id is falsy)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidationNewRecord:
|
||||
"""Tests for _validation() with a brand-new record (no id)."""
|
||||
|
||||
def _build_mock_db(self, name_exists: bool = False):
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = (
|
||||
MagicMock() if name_exists else None
|
||||
)
|
||||
return mock_db
|
||||
|
||||
@patch.object(APIBasedExtensionService, "_ping_connection")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_valid_new_extension_passes(self, mock_db, mock_ping):
|
||||
"""A new record with all valid fields should pass without exceptions."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
|
||||
ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey")
|
||||
|
||||
# Should not raise
|
||||
APIBasedExtensionService._validation(ext)
|
||||
mock_ping.assert_called_once_with(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_name_is_empty(self, mock_db):
|
||||
"""Empty name raises ValueError."""
|
||||
ext = _make_extension(id_=None, name="")
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_name_is_none(self, mock_db):
|
||||
"""None name raises ValueError."""
|
||||
ext = _make_extension(id_=None, name=None)
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_name_already_exists_for_new_record(self, mock_db):
|
||||
"""A new record whose name already exists raises ValueError."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = (
|
||||
MagicMock()
|
||||
)
|
||||
ext = _make_extension(id_=None, name="duplicate-name")
|
||||
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_api_endpoint_is_empty(self, mock_db):
|
||||
"""Empty api_endpoint raises ValueError."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
|
||||
ext = _make_extension(id_=None, api_endpoint="")
|
||||
|
||||
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_api_endpoint_is_none(self, mock_db):
|
||||
"""None api_endpoint raises ValueError."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
|
||||
ext = _make_extension(id_=None, api_endpoint=None)
|
||||
|
||||
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_api_key_is_empty(self, mock_db):
|
||||
"""Empty api_key raises ValueError."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
|
||||
ext = _make_extension(id_=None, api_key="")
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_api_key_is_none(self, mock_db):
|
||||
"""None api_key raises ValueError."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
|
||||
ext = _make_extension(id_=None, api_key=None)
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_api_key_too_short(self, mock_db):
|
||||
"""api_key shorter than 5 characters raises ValueError."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
|
||||
ext = _make_extension(id_=None, api_key="abc")
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_api_key_exactly_four_chars(self, mock_db):
|
||||
"""api_key with exactly 4 characters raises ValueError (boundary condition)."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
|
||||
ext = _make_extension(id_=None, api_key="1234")
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
@patch.object(APIBasedExtensionService, "_ping_connection")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping):
|
||||
"""api_key with exactly 5 characters should pass (boundary condition)."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None
|
||||
ext = _make_extension(id_=None, api_key="12345")
|
||||
|
||||
# Should not raise
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _validation (existing record — id is truthy)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidationExistingRecord:
|
||||
"""Tests for _validation() with an existing record (id is set)."""
|
||||
|
||||
@patch.object(APIBasedExtensionService, "_ping_connection")
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_valid_existing_extension_passes(self, mock_db, mock_ping):
|
||||
"""An existing record whose name is unique (excluding self) should pass."""
|
||||
# .where(...).first() → None means no *other* record has that name
|
||||
(
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value
|
||||
) = None
|
||||
ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey")
|
||||
|
||||
# Should not raise
|
||||
APIBasedExtensionService._validation(ext)
|
||||
mock_ping.assert_called_once_with(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.db")
|
||||
def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db):
|
||||
"""Existing record cannot use a name already owned by a different record."""
|
||||
(
|
||||
mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value
|
||||
) = MagicMock()
|
||||
ext = _make_extension(id_="existing-id", name="taken-name")
|
||||
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService._validation(ext)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _ping_connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPingConnection:
|
||||
"""Tests for APIBasedExtensionService._ping_connection."""
|
||||
|
||||
@patch("services.api_based_extension_service.APIBasedExtensionRequestor")
|
||||
def test_successful_ping_returns_pong(self, mock_requestor_class):
|
||||
"""When the endpoint returns {"result": "pong"}, no exception is raised."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.request.return_value = {"result": "pong"}
|
||||
mock_requestor_class.return_value = mock_client
|
||||
|
||||
ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key")
|
||||
# Should not raise
|
||||
APIBasedExtensionService._ping_connection(ext)
|
||||
|
||||
mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key)
|
||||
|
||||
@patch("services.api_based_extension_service.APIBasedExtensionRequestor")
|
||||
def test_wrong_ping_response_raises_value_error(self, mock_requestor_class):
|
||||
"""When the response is not {"result": "pong"}, a ValueError is raised."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.request.return_value = {"result": "error"}
|
||||
mock_requestor_class.return_value = mock_client
|
||||
|
||||
ext = _make_extension()
|
||||
with pytest.raises(ValueError, match="connection error"):
|
||||
APIBasedExtensionService._ping_connection(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.APIBasedExtensionRequestor")
|
||||
def test_network_exception_wraps_in_value_error(self, mock_requestor_class):
|
||||
"""Any exception raised during request is wrapped in a ValueError."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.request.side_effect = ConnectionError("network failure")
|
||||
mock_requestor_class.return_value = mock_client
|
||||
|
||||
ext = _make_extension()
|
||||
with pytest.raises(ValueError, match="connection error: network failure"):
|
||||
APIBasedExtensionService._ping_connection(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.APIBasedExtensionRequestor")
|
||||
def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class):
|
||||
"""Exception raised by the requestor constructor itself is wrapped."""
|
||||
mock_requestor_class.side_effect = RuntimeError("bad config")
|
||||
|
||||
ext = _make_extension()
|
||||
with pytest.raises(ValueError, match="connection error: bad config"):
|
||||
APIBasedExtensionService._ping_connection(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.APIBasedExtensionRequestor")
|
||||
def test_missing_result_key_raises_value_error(self, mock_requestor_class):
|
||||
"""A response dict without a 'result' key does not equal 'pong' → raises."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.request.return_value = {} # no 'result' key
|
||||
mock_requestor_class.return_value = mock_client
|
||||
|
||||
ext = _make_extension()
|
||||
with pytest.raises(ValueError, match="connection error"):
|
||||
APIBasedExtensionService._ping_connection(ext)
|
||||
|
||||
@patch("services.api_based_extension_service.APIBasedExtensionRequestor")
|
||||
def test_uses_ping_extension_point(self, mock_requestor_class):
|
||||
"""The PING extension point is passed to the client.request call."""
|
||||
from models.api_based_extension import APIBasedExtensionPoint
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.request.return_value = {"result": "pong"}
|
||||
mock_requestor_class.return_value = mock_client
|
||||
|
||||
ext = _make_extension()
|
||||
APIBasedExtensionService._ping_connection(ext)
|
||||
|
||||
call_kwargs = mock_client.request.call_args
|
||||
assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING
|
||||
assert call_kwargs.kwargs["params"] == {}
|
||||
@ -1,73 +0,0 @@
|
||||
import base64
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.attachment_service as attachment_service_module
|
||||
from models.model import UploadFile
|
||||
from services.attachment_service import AttachmentService
|
||||
|
||||
|
||||
class TestAttachmentService:
|
||||
def test_should_initialize_with_sessionmaker_when_sessionmaker_is_provided(self):
|
||||
"""Test that AttachmentService keeps the provided sessionmaker instance."""
|
||||
session_factory = sessionmaker()
|
||||
|
||||
service = AttachmentService(session_factory=session_factory)
|
||||
|
||||
assert service._session_maker is session_factory
|
||||
|
||||
def test_should_initialize_with_bound_sessionmaker_when_engine_is_provided(self):
|
||||
"""Test that AttachmentService builds a sessionmaker bound to the provided engine."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
service = AttachmentService(session_factory=engine)
|
||||
session = service._session_maker()
|
||||
try:
|
||||
assert session.bind == engine
|
||||
finally:
|
||||
session.close()
|
||||
engine.dispose()
|
||||
|
||||
@pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
|
||||
def test_should_raise_assertion_error_when_session_factory_type_is_invalid(self, invalid_session_factory):
|
||||
"""Test that invalid session_factory types are rejected."""
|
||||
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
|
||||
AttachmentService(session_factory=invalid_session_factory)
|
||||
|
||||
def test_should_return_base64_encoded_blob_when_file_exists(self):
|
||||
"""Test that existing files are loaded from storage and returned as base64."""
|
||||
service = AttachmentService(session_factory=sessionmaker())
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.key = "upload-file-key"
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = upload_file
|
||||
service._session_maker = MagicMock(return_value=session)
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
|
||||
result = service.get_file_base64("file-123")
|
||||
|
||||
assert result == base64.b64encode(b"binary-content").decode()
|
||||
service._session_maker.assert_called_once_with(expire_on_commit=False)
|
||||
session.query.assert_called_once_with(UploadFile)
|
||||
mock_load.assert_called_once_with("upload-file-key")
|
||||
|
||||
def test_should_raise_not_found_when_file_does_not_exist(self):
|
||||
"""Test that missing files raise NotFound and never call storage."""
|
||||
service = AttachmentService(session_factory=sessionmaker())
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
service._session_maker = MagicMock(return_value=session)
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
service.get_file_base64("missing-file")
|
||||
|
||||
service._session_maker.assert_called_once_with(expire_on_commit=False)
|
||||
session.query.assert_called_once_with(UploadFile)
|
||||
mock_load.assert_not_called()
|
||||
@ -1,75 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.variables import StringVariable
|
||||
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
|
||||
|
||||
|
||||
class TestConversationVariableUpdater:
|
||||
def test_should_update_conversation_variable_data_and_commit(self):
|
||||
"""Test update persists serialized variable data when the row exists."""
|
||||
conversation_id = "conv-123"
|
||||
variable = StringVariable(
|
||||
id="var-123",
|
||||
name="topic",
|
||||
value="new value",
|
||||
)
|
||||
expected_json = variable.model_dump_json()
|
||||
|
||||
row = SimpleNamespace(data="old value")
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = row
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
session_maker = MagicMock(return_value=session_context)
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
updater.update(conversation_id=conversation_id, variable=variable)
|
||||
|
||||
session_maker.assert_called_once_with()
|
||||
session.scalar.assert_called_once()
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled_params = stmt.compile().params
|
||||
assert variable.id in compiled_params.values()
|
||||
assert conversation_id in compiled_params.values()
|
||||
assert row.data == expected_json
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_should_raise_not_found_error_when_conversation_variable_missing(self):
|
||||
"""Test update raises ConversationVariableNotFoundError when no matching row exists."""
|
||||
conversation_id = "conv-404"
|
||||
variable = StringVariable(
|
||||
id="var-404",
|
||||
name="topic",
|
||||
value="value",
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
session_maker = MagicMock(return_value=session_context)
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
|
||||
updater.update(conversation_id=conversation_id, variable=variable)
|
||||
|
||||
session.commit.assert_not_called()
|
||||
|
||||
def test_should_do_nothing_when_flush_is_called(self):
|
||||
"""Test flush currently behaves as a no-op and returns None."""
|
||||
session_maker = MagicMock()
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
result = updater.flush()
|
||||
|
||||
assert result is None
|
||||
session_maker.assert_not_called()
|
||||
@ -1,157 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import services.credit_pool_service as credit_pool_service_module
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_deduction_setup():
|
||||
"""Fixture providing common setup for credit deduction tests."""
|
||||
pool = SimpleNamespace(remaining_credits=50)
|
||||
fake_engine = MagicMock()
|
||||
session = MagicMock()
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
mock_get_pool = patch.object(CreditPoolService, "get_pool", return_value=pool)
|
||||
mock_db = patch.object(credit_pool_service_module, "db", new=SimpleNamespace(engine=fake_engine))
|
||||
mock_session = patch.object(credit_pool_service_module, "Session", return_value=session_context)
|
||||
|
||||
return {
|
||||
"pool": pool,
|
||||
"fake_engine": fake_engine,
|
||||
"session": session,
|
||||
"session_context": session_context,
|
||||
"patches": (mock_get_pool, mock_db, mock_session),
|
||||
}
|
||||
|
||||
|
||||
class TestCreditPoolService:
|
||||
def test_should_create_default_pool_with_trial_type_and_configured_quota(self):
|
||||
"""Test create_default_pool persists a trial pool using configured hosted credits."""
|
||||
tenant_id = "tenant-123"
|
||||
hosted_pool_credits = 5000
|
||||
|
||||
with (
|
||||
patch.object(credit_pool_service_module.dify_config, "HOSTED_POOL_CREDITS", hosted_pool_credits),
|
||||
patch.object(credit_pool_service_module, "db") as mock_db,
|
||||
):
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
assert isinstance(pool, TenantCreditPool)
|
||||
assert pool.tenant_id == tenant_id
|
||||
assert pool.pool_type == "trial"
|
||||
assert pool.quota_limit == hosted_pool_credits
|
||||
assert pool.quota_used == 0
|
||||
mock_db.session.add.assert_called_once_with(pool)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_should_return_first_pool_from_query_when_get_pool_called(self):
|
||||
"""Test get_pool queries by tenant and pool_type and returns first result."""
|
||||
tenant_id = "tenant-123"
|
||||
pool_type = "enterprise"
|
||||
expected_pool = MagicMock(spec=TenantCreditPool)
|
||||
|
||||
with patch.object(credit_pool_service_module, "db") as mock_db:
|
||||
query = mock_db.session.query.return_value
|
||||
filtered_query = query.filter_by.return_value
|
||||
filtered_query.first.return_value = expected_pool
|
||||
|
||||
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=pool_type)
|
||||
|
||||
assert result == expected_pool
|
||||
mock_db.session.query.assert_called_once_with(TenantCreditPool)
|
||||
query.filter_by.assert_called_once_with(tenant_id=tenant_id, pool_type=pool_type)
|
||||
filtered_query.first.assert_called_once()
|
||||
|
||||
def test_should_return_false_when_pool_not_found_in_check_credits_available(self):
|
||||
"""Test check_credits_available returns False when tenant has no pool."""
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=None) as mock_get_pool:
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
assert result is False
|
||||
mock_get_pool.assert_called_once_with("tenant-123", "trial")
|
||||
|
||||
def test_should_return_true_when_remaining_credits_cover_required_amount(self):
|
||||
"""Test check_credits_available returns True when remaining credits are sufficient."""
|
||||
pool = SimpleNamespace(remaining_credits=100)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool) as mock_get_pool:
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60)
|
||||
|
||||
assert result is True
|
||||
mock_get_pool.assert_called_once_with("tenant-123", "trial")
|
||||
|
||||
def test_should_return_false_when_remaining_credits_are_insufficient(self):
|
||||
"""Test check_credits_available returns False when required credits exceed remaining credits."""
|
||||
pool = SimpleNamespace(remaining_credits=30)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool):
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_should_raise_quota_exceeded_when_pool_not_found_in_check_and_deduct(self):
|
||||
"""Test check_and_deduct_credits raises when tenant credit pool does not exist."""
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=None):
|
||||
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
def test_should_raise_quota_exceeded_when_pool_has_no_remaining_credits(self):
|
||||
"""Test check_and_deduct_credits raises when remaining credits are zero or negative."""
|
||||
pool = SimpleNamespace(remaining_credits=0)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool):
|
||||
with pytest.raises(QuotaExceededError, match="No credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
def test_should_deduct_minimum_of_required_and_remaining_credits(self, mock_credit_deduction_setup):
|
||||
"""Test check_and_deduct_credits updates quota_used by the actual deducted amount."""
|
||||
tenant_id = "tenant-123"
|
||||
pool_type = "trial"
|
||||
credits_required = 200
|
||||
remaining_credits = 120
|
||||
expected_deducted_credits = 120
|
||||
|
||||
mock_credit_deduction_setup["pool"].remaining_credits = remaining_credits
|
||||
patches = mock_credit_deduction_setup["patches"]
|
||||
session = mock_credit_deduction_setup["session"]
|
||||
|
||||
with patches[0], patches[1], patches[2]:
|
||||
result = CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=credits_required,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
|
||||
assert result == expected_deducted_credits
|
||||
session.execute.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
stmt = session.execute.call_args.args[0]
|
||||
compiled_params = stmt.compile().params
|
||||
assert tenant_id in compiled_params.values()
|
||||
assert pool_type in compiled_params.values()
|
||||
assert expected_deducted_credits in compiled_params.values()
|
||||
|
||||
def test_should_raise_quota_exceeded_when_deduction_update_fails(self, mock_credit_deduction_setup):
|
||||
"""Test check_and_deduct_credits translates DB update failures to QuotaExceededError."""
|
||||
mock_credit_deduction_setup["pool"].remaining_credits = 50
|
||||
mock_credit_deduction_setup["session"].execute.side_effect = Exception("db failure")
|
||||
session = mock_credit_deduction_setup["session"]
|
||||
|
||||
patches = mock_credit_deduction_setup["patches"]
|
||||
mock_logger = patch.object(credit_pool_service_module, "logger")
|
||||
|
||||
with patches[0], patches[1], patches[2], mock_logger as mock_logger_obj:
|
||||
with pytest.raises(QuotaExceededError, match="Failed to deduct credits"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
session.commit.assert_not_called()
|
||||
mock_logger_obj.exception.assert_called_once()
|
||||
@ -1,305 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetPermission, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetPermissionTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset permission tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
tenant_id: str = "test-tenant-123",
|
||||
created_by: str = "creator-456",
|
||||
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.created_by = created_by
|
||||
dataset.permission = permission
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(
|
||||
user_id: str = "user-789",
|
||||
tenant_id: str = "test-tenant-123",
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock user with specified attributes."""
|
||||
user = Mock(spec=Account)
|
||||
user.id = user_id
|
||||
user.current_tenant_id = tenant_id
|
||||
user.current_role = role
|
||||
for key, value in kwargs.items():
|
||||
setattr(user, key, value)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
account_id: str = "user-789",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset permission record."""
|
||||
permission = Mock(spec=DatasetPermission)
|
||||
permission.dataset_id = dataset_id
|
||||
permission.account_id = account_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(permission, key, value)
|
||||
return permission
|
||||
|
||||
|
||||
class TestDatasetPermissionService:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.check_dataset_permission method.
|
||||
|
||||
This test suite covers all permission scenarios including:
|
||||
- Cross-tenant access restrictions
|
||||
- Owner privilege checks
|
||||
- Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
|
||||
- Explicit permission checks for PARTIAL_TEAM
|
||||
- Error conditions and logging
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with patch("services.dataset_service.db.session") as mock_session:
|
||||
yield {
|
||||
"db_session": mock_session,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logging_dependencies(self):
|
||||
"""Mock setup for logging tests."""
|
||||
with patch("services.dataset_service.logger") as mock_logging:
|
||||
yield {
|
||||
"logging": mock_logging,
|
||||
}
|
||||
|
||||
def _assert_permission_check_passes(self, dataset: Mock, user: Mock):
|
||||
"""Helper method to verify that permission check passes without raising exceptions."""
|
||||
# Should not raise any exception
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
def _assert_permission_check_fails(
|
||||
self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset."
|
||||
):
|
||||
"""Helper method to verify that permission check fails with expected error."""
|
||||
with pytest.raises(NoPermissionError, match=expected_message):
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str):
|
||||
"""Helper method to verify database query calls for permission checks."""
|
||||
mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id)
|
||||
|
||||
def _assert_database_query_not_called(self, mock_session: Mock):
|
||||
"""Helper method to verify that database query was not called."""
|
||||
mock_session.query.assert_not_called()
|
||||
|
||||
# ==================== Cross-Tenant Access Tests ====================
|
||||
|
||||
def test_permission_check_different_tenant_should_fail(self):
|
||||
"""Test that users from different tenants cannot access dataset regardless of other permissions."""
|
||||
# Create dataset and user from different tenants
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM
|
||||
)
|
||||
user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Should fail due to different tenant
|
||||
self._assert_permission_check_fails(dataset, user)
|
||||
|
||||
# ==================== Owner Privilege Tests ====================
|
||||
|
||||
def test_owner_can_access_any_dataset(self):
|
||||
"""Test that tenant owners can access any dataset regardless of permission level."""
|
||||
# Create dataset with restrictive permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
|
||||
|
||||
# Create owner user
|
||||
owner_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="owner-999", role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
# Owner should have access regardless of dataset permission
|
||||
self._assert_permission_check_passes(dataset, owner_user)
|
||||
|
||||
# ==================== ONLY_ME Permission Tests ====================
|
||||
|
||||
def test_only_me_permission_creator_can_access(self):
|
||||
"""Test ONLY_ME permission allows only the dataset creator to access."""
|
||||
# Create dataset with ONLY_ME permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should be able to access
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
|
||||
def test_only_me_permission_others_cannot_access(self):
|
||||
"""Test ONLY_ME permission denies access to non-creators."""
|
||||
# Create dataset with ONLY_ME permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Non-creator should be denied access
|
||||
self._assert_permission_check_fails(dataset, normal_user)
|
||||
|
||||
# ==================== ALL_TEAM Permission Tests ====================
|
||||
|
||||
def test_all_team_permission_allows_access(self):
|
||||
"""Test ALL_TEAM permission allows any team member to access the dataset."""
|
||||
# Create dataset with ALL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM)
|
||||
|
||||
# Create different types of team members
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
editor_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="editor-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# All team members should have access
|
||||
self._assert_permission_check_passes(dataset, normal_user)
|
||||
self._assert_permission_check_passes(dataset, editor_user)
|
||||
|
||||
# ==================== PARTIAL_TEAM Permission Tests ====================
|
||||
|
||||
def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission allows creator to access without database query."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should have access without database query
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"])
|
||||
|
||||
def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission allows users with explicit permission records."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return a permission record
|
||||
mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset.id, account_id=normal_user.id
|
||||
)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission
|
||||
|
||||
# User with explicit permission should have access
|
||||
self._assert_permission_check_passes(dataset, normal_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
|
||||
|
||||
def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission denies users without explicit permission records."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
# User without explicit permission should be denied access
|
||||
self._assert_permission_check_fails(dataset, normal_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
|
||||
|
||||
def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies):
|
||||
"""Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create a different user (not the creator)
|
||||
other_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="other-user-123", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
# Non-creator without explicit permission should be denied access
|
||||
self._assert_permission_check_fails(dataset, other_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id)
|
||||
|
||||
# ==================== Enum Usage Tests ====================
|
||||
|
||||
def test_partial_team_permission_uses_correct_enum(self):
|
||||
"""Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals."""
|
||||
# Create dataset with PARTIAL_TEAM permission using enum
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should always have access regardless of permission level
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
|
||||
# ==================== Logging Tests ====================
|
||||
|
||||
def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies):
|
||||
"""Test that permission denied events are properly logged for debugging purposes."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
# Attempt permission check (should fail)
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.check_dataset_permission(dataset, normal_user)
|
||||
|
||||
# Verify debug message was logged with correct user and dataset information
|
||||
mock_logging_dependencies["logging"].debug.assert_called_with(
|
||||
"User %s does not have permission to access dataset %s", normal_user.id, dataset.id
|
||||
)
|
||||
@ -1,100 +0,0 @@
|
||||
import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import Dataset, Document
|
||||
from services.dataset_service import DocumentService
|
||||
from tests.unit_tests.conftest import redis_mock
|
||||
|
||||
|
||||
class DocumentBatchUpdateTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for document batch update tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-456") -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(user_id: str = "user-789") -> Mock:
|
||||
"""Create a mock user."""
|
||||
user = Mock()
|
||||
user.id = user_id
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_document_mock(
|
||||
document_id: str = "doc-1",
|
||||
name: str = "test_document.pdf",
|
||||
enabled: bool = True,
|
||||
archived: bool = False,
|
||||
indexing_status: str = "completed",
|
||||
completed_at: datetime.datetime | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock document with specified attributes."""
|
||||
document = Mock(spec=Document)
|
||||
document.id = document_id
|
||||
document.name = name
|
||||
document.enabled = enabled
|
||||
document.archived = archived
|
||||
document.indexing_status = indexing_status
|
||||
document.completed_at = completed_at or datetime.datetime.now()
|
||||
|
||||
document.disabled_at = None
|
||||
document.disabled_by = None
|
||||
document.archived_at = None
|
||||
document.archived_by = None
|
||||
document.updated_at = None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(document, key, value)
|
||||
return document
|
||||
|
||||
|
||||
class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
"""Unit tests for non-SQL path in DocumentService.batch_update_document_status."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_service_dependencies(self):
|
||||
"""Common mock setup for document service dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.DocumentService.get_document") as mock_get_doc,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
|
||||
yield {
|
||||
"get_document": mock_get_doc,
|
||||
"db_session": mock_db,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
}
|
||||
|
||||
def test_batch_update_invalid_action_error(self, mock_document_service_dependencies):
|
||||
"""Test that ValueError is raised when an invalid action is provided."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True)
|
||||
mock_document_service_dependencies["get_document"].return_value = doc
|
||||
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
invalid_action = "invalid_action"
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user
|
||||
)
|
||||
|
||||
assert invalid_action in str(exc_info.value)
|
||||
assert "Invalid action" in str(exc_info.value)
|
||||
|
||||
redis_mock.setex.assert_not_called()
|
||||
@ -1,50 +0,0 @@
|
||||
"""Unit tests for non-SQL validation paths in DatasetService dataset creation."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
|
||||
|
||||
class TestDatasetServiceCreateRagPipelineDatasetNonSQL:
|
||||
"""Unit coverage for non-SQL validation in create_empty_rag_pipeline_dataset."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_pipeline_dependencies(self):
|
||||
"""Patch database session and current_user for validation-only unit coverage."""
|
||||
with (
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.current_user") as mock_current_user,
|
||||
):
|
||||
yield {
|
||||
"db_session": mock_db,
|
||||
"current_user_mock": mock_current_user,
|
||||
}
|
||||
|
||||
def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies):
|
||||
"""Raise ValueError when current_user.id is unavailable before SQL persistence."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = None
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name="Test Dataset",
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Current user or current user id not found"):
|
||||
DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id,
|
||||
rag_pipeline_dataset_create_entity=entity,
|
||||
)
|
||||
@ -1,57 +0,0 @@
|
||||
"""
|
||||
Unit tests for archived workflow run deletion service.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestArchivedWorkflowRunDeletion:
|
||||
def test_delete_by_run_id_calls_delete_run(self):
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
repo = MagicMock()
|
||||
repo.get_archived_run_ids.return_value = {"run-1"}
|
||||
run = MagicMock()
|
||||
run.id = "run-1"
|
||||
run.tenant_id = "tenant-1"
|
||||
|
||||
session = MagicMock()
|
||||
session.get.return_value = run
|
||||
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
mock_db = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
|
||||
patch(
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker",
|
||||
return_value=session_maker,
|
||||
autospec=True,
|
||||
),
|
||||
patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True),
|
||||
patch.object(
|
||||
deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True
|
||||
) as mock_delete_run,
|
||||
):
|
||||
result = deleter.delete_by_run_id("run-1")
|
||||
|
||||
assert result.success is True
|
||||
mock_delete_run.assert_called_once_with(run)
|
||||
|
||||
def test_delete_run_dry_run(self):
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion(dry_run=True)
|
||||
run = MagicMock()
|
||||
run.id = "run-1"
|
||||
run.tenant_id = "tenant-1"
|
||||
|
||||
with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo:
|
||||
result = deleter._delete_run(run)
|
||||
|
||||
assert result.success is True
|
||||
mock_get_repo.assert_not_called()
|
||||
@ -1,8 +0,0 @@
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
def test_normalize_display_status_alias_mapping():
|
||||
assert DocumentService.normalize_display_status("ACTIVE") == "available"
|
||||
assert DocumentService.normalize_display_status("enabled") == "available"
|
||||
assert DocumentService.normalize_display_status("archived") == "archived"
|
||||
assert DocumentService.normalize_display_status("unknown") is None
|
||||
@ -1,841 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import App, DefaultEndUserSessionID, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
class TestEndUserServiceFactory:
|
||||
"""Factory class for creating test data and mock objects for end user service tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_app_mock(
|
||||
app_id: str = "app-123",
|
||||
tenant_id: str = "tenant-456",
|
||||
name: str = "Test App",
|
||||
) -> MagicMock:
|
||||
"""Create a mock App object."""
|
||||
app = MagicMock(spec=App)
|
||||
app.id = app_id
|
||||
app.tenant_id = tenant_id
|
||||
app.name = name
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_end_user_mock(
|
||||
user_id: str = "user-789",
|
||||
tenant_id: str = "tenant-456",
|
||||
app_id: str = "app-123",
|
||||
session_id: str = "session-001",
|
||||
type: InvokeFrom = InvokeFrom.SERVICE_API,
|
||||
is_anonymous: bool = False,
|
||||
) -> MagicMock:
|
||||
"""Create a mock EndUser object."""
|
||||
end_user = MagicMock(spec=EndUser)
|
||||
end_user.id = user_id
|
||||
end_user.tenant_id = tenant_id
|
||||
end_user.app_id = app_id
|
||||
end_user.session_id = session_id
|
||||
end_user.type = type
|
||||
end_user.is_anonymous = is_anonymous
|
||||
end_user.external_user_id = session_id
|
||||
return end_user
|
||||
|
||||
|
||||
class TestEndUserServiceGetEndUserById:
|
||||
"""Unit tests for EndUserService.get_end_user_by_id method."""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory):
|
||||
"""Test successful retrieval of end user by ID."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
end_user_id = "user-789"
|
||||
|
||||
mock_end_user = factory.create_end_user_mock(user_id=end_user_id, tenant_id=tenant_id, app_id=app_id)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = mock_end_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_end_user
|
||||
mock_session.query.assert_called_once_with(EndUser)
|
||||
mock_query.where.assert_called_once()
|
||||
mock_query.first.assert_called_once()
|
||||
mock_context.__enter__.assert_called_once()
|
||||
mock_context.__exit__.assert_called_once()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class):
|
||||
"""Test retrieval of non-existent end user returns None."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
end_user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class):
|
||||
"""Test that query parameters are correctly applied."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
end_user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id)
|
||||
|
||||
# Assert
|
||||
# Verify the where clause was called with the correct conditions
|
||||
call_args = mock_query.where.call_args[0]
|
||||
assert len(call_args) == 3
|
||||
# Check that the conditions match the expected filters
|
||||
# (We can't easily test the exact conditions without importing SQLAlchemy)
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUser:
|
||||
"""Unit tests for EndUserService.get_or_create_end_user method."""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
@patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type")
|
||||
def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, factory):
|
||||
"""Test get_or_create_end_user with specific user_id."""
|
||||
# Arrange
|
||||
app_mock = factory.create_app_mock()
|
||||
user_id = "user-123"
|
||||
expected_end_user = factory.create_end_user_mock()
|
||||
mock_get_or_create_by_type.return_value = expected_end_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user(app_mock, user_id)
|
||||
|
||||
# Assert
|
||||
assert result == expected_end_user
|
||||
mock_get_or_create_by_type.assert_called_once_with(
|
||||
InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, user_id
|
||||
)
|
||||
|
||||
@patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type")
|
||||
def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory):
|
||||
"""Test get_or_create_end_user without user_id (None)."""
|
||||
# Arrange
|
||||
app_mock = factory.create_app_mock()
|
||||
expected_end_user = factory.create_end_user_mock()
|
||||
mock_get_or_create_by_type.return_value = expected_end_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user(app_mock, None)
|
||||
|
||||
# Assert
|
||||
assert result == expected_end_user
|
||||
mock_get_or_create_by_type.assert_called_once_with(
|
||||
InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, None
|
||||
)
|
||||
|
||||
|
||||
class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
"""
|
||||
Unit tests for EndUserService.get_or_create_end_user_by_type method.
|
||||
|
||||
This test suite covers:
|
||||
- Creating end users with different InvokeFrom types
|
||||
- Type migration for legacy users
|
||||
- Query ordering and prioritization
|
||||
- Session management
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating a new end user with specific user_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None # No existing user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify new EndUser was created with correct parameters
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.tenant_id == tenant_id
|
||||
assert added_user.app_id == app_id
|
||||
assert added_user.type == type_enum
|
||||
assert added_user.session_id == user_id
|
||||
assert added_user.external_user_id == user_id
|
||||
assert added_user._is_anonymous is False
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory):
|
||||
"""Test creating a new end user with default session ID."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = None
|
||||
type_enum = InvokeFrom.WEB_APP
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None # No existing user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
assert added_user._is_anonymous is True
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
@patch("services.end_user_service.logger")
|
||||
def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory):
|
||||
"""Test retrieving existing user with same type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
@patch("services.end_user_service.logger")
|
||||
def test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory):
|
||||
"""Test upgrading existing user with different type."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
old_type = InvokeFrom.WEB_APP
|
||||
new_type = InvokeFrom.SERVICE_API
|
||||
|
||||
existing_user = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = existing_user
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == existing_user
|
||||
assert existing_user.type == new_type
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_logger.info.assert_called_once()
|
||||
logger_call_args = mock_logger.info.call_args[0]
|
||||
assert "Upgrading legacy EndUser" in logger_call_args[0]
|
||||
# The old and new types are passed as separate arguments
|
||||
assert mock_logger.info.call_args[0][1] == existing_user.id
|
||||
assert mock_logger.info.call_args[0][2] == old_type
|
||||
assert mock_logger.info.call_args[0][3] == new_type
|
||||
assert mock_logger.info.call_args[0][4] == user_id
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory):
|
||||
"""Test that query ordering prioritizes exact type matches."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
target_type = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
EndUserService.get_or_create_end_user_by_type(
|
||||
type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_query.order_by.assert_called_once()
|
||||
# Verify that case statement is used for ordering
|
||||
order_by_call = mock_query.order_by.call_args[0][0]
|
||||
# The exact structure depends on SQLAlchemy's case implementation
|
||||
# but we can verify it was called
|
||||
|
||||
# Test 10: Session context manager properly closes
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_session_context_manager_closes(self, mock_db, mock_session_class, factory):
|
||||
"""Test that Session context manager is properly used."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
EndUserService.get_or_create_end_user_by_type(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify context manager was entered and exited
|
||||
mock_context.__enter__.assert_called_once()
|
||||
mock_context.__exit__.assert_called_once()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_all_invokefrom_types_supported(self, mock_db, mock_session_class):
|
||||
"""Test that all InvokeFrom enum values are supported."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_id = "app-456"
|
||||
user_id = "user-789"
|
||||
|
||||
for invoke_type in InvokeFrom:
|
||||
with patch("services.end_user_service.Session") as mock_session_class:
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = EndUserService.get_or_create_end_user_by_type(
|
||||
type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add.call_args[0][0]
|
||||
assert added_user.type == invoke_type
|
||||
|
||||
|
||||
class TestEndUserServiceCreateEndUserBatch:
|
||||
"""Unit tests for EndUserService.create_end_user_batch method."""
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_empty_app_ids(self, mock_db, mock_session_class):
|
||||
"""Test batch creation with empty app_ids list."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids: list[str] = []
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
# Act
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
mock_session_class.assert_not_called()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_default_session_id(self, mock_db, mock_session_class):
|
||||
"""Test batch creation with empty user_id (uses default session)."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456", "app-789"]
|
||||
user_id = ""
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [] # No existing users
|
||||
|
||||
# Act
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
for app_id, end_user in result.items():
|
||||
assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
assert end_user._is_anonymous is True
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class):
|
||||
"""Test that duplicate app_ids are deduplicated while preserving order."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"]
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [] # No existing users
|
||||
|
||||
# Act
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should have 3 unique app_ids in original order
|
||||
assert len(result) == 3
|
||||
assert "app-456" in result
|
||||
assert "app-789" in result
|
||||
assert "app-123" in result
|
||||
|
||||
# Verify the order is preserved
|
||||
added_users = mock_session.add_all.call_args[0][0]
|
||||
assert len(added_users) == 3
|
||||
assert added_users[0].app_id == "app-456"
|
||||
assert added_users[1].app_id == "app-789"
|
||||
assert added_users[2].app_id == "app-123"
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory):
|
||||
"""Test batch creation when all users already exist."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456", "app-789"]
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
existing_user1 = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum
|
||||
)
|
||||
existing_user2 = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id, app_id="app-789", session_id=user_id, type=type_enum
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [existing_user1, existing_user2]
|
||||
|
||||
# Act
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result["app-456"] == existing_user1
|
||||
assert result["app-789"] == existing_user2
|
||||
mock_session.add_all.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory):
|
||||
"""Test batch creation with some existing and some new users."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456", "app-789", "app-123"]
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
existing_user1 = factory.create_end_user_mock(
|
||||
tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum
|
||||
)
|
||||
# app-789 and app-123 don't exist
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [existing_user1]
|
||||
|
||||
# Act
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert result["app-456"] == existing_user1
|
||||
assert "app-789" in result
|
||||
assert "app-123" in result
|
||||
|
||||
# Should create 2 new users
|
||||
mock_session.add_all.assert_called_once()
|
||||
added_users = mock_session.add_all.call_args[0][0]
|
||||
assert len(added_users) == 2
|
||||
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory):
|
||||
"""Test batch creation handles duplicates in existing users gracefully."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456"]
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
# Simulate duplicate records in database
|
||||
existing_user1 = factory.create_end_user_mock(
|
||||
user_id="user-1", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum
|
||||
)
|
||||
existing_user2 = factory.create_end_user_mock(
|
||||
user_id="user-2", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [existing_user1, existing_user2]
|
||||
|
||||
# Act
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
# Should prefer the first one found
|
||||
assert result["app-456"] == existing_user1
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class):
|
||||
"""Test batch creation with all InvokeFrom types."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456"]
|
||||
user_id = "user-789"
|
||||
|
||||
for invoke_type in InvokeFrom:
|
||||
with patch("services.end_user_service.Session") as mock_session_class:
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [] # No existing users
|
||||
|
||||
# Act
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=invoke_type, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_user = mock_session.add_all.call_args[0][0][0]
|
||||
assert added_user.type == invoke_type
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory):
|
||||
"""Test batch creation with single app_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456"]
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [] # No existing users
|
||||
|
||||
# Act
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert "app-456" in result
|
||||
mock_session.add_all.assert_called_once()
|
||||
added_users = mock_session.add_all.call_args[0][0]
|
||||
assert len(added_users) == 1
|
||||
assert added_users[0].app_id == "app-456"
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class):
|
||||
"""Test batch creation correctly sets anonymous flag."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456", "app-789"]
|
||||
|
||||
# Test with regular user ID
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [] # No existing users
|
||||
|
||||
# Act - authenticated user
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="user-789"
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_users = mock_session.add_all.call_args[0][0]
|
||||
for user in added_users:
|
||||
assert user._is_anonymous is False
|
||||
|
||||
# Test with default session ID
|
||||
mock_session.reset_mock()
|
||||
mock_query.reset_mock()
|
||||
mock_query.all.return_value = []
|
||||
|
||||
# Act - anonymous user
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API,
|
||||
tenant_id=tenant_id,
|
||||
app_ids=app_ids,
|
||||
user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID,
|
||||
)
|
||||
|
||||
# Assert
|
||||
added_users = mock_session.add_all.call_args[0][0]
|
||||
for user in added_users:
|
||||
assert user._is_anonymous is True
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_efficient_single_query(self, mock_db, mock_session_class):
|
||||
"""Test that batch creation uses efficient single query for existing users."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456", "app-789", "app-123"]
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [] # No existing users
|
||||
|
||||
# Act
|
||||
EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
# Should make exactly one query to check for existing users
|
||||
mock_session.query.assert_called_once_with(EndUser)
|
||||
mock_query.where.assert_called_once()
|
||||
mock_query.all.assert_called_once()
|
||||
|
||||
# Verify the where clause uses .in_() for app_ids
|
||||
where_call = mock_query.where.call_args[0]
|
||||
# The exact structure depends on SQLAlchemy implementation
|
||||
# but we can verify it was called with the right parameters
|
||||
|
||||
@patch("services.end_user_service.Session")
|
||||
@patch("services.end_user_service.db")
|
||||
def test_create_batch_session_context_manager(self, mock_db, mock_session_class):
|
||||
"""Test that batch creation properly uses session context manager."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
app_ids = ["app-456"]
|
||||
user_id = "user-789"
|
||||
type_enum = InvokeFrom.SERVICE_API
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value = mock_context
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.all.return_value = [] # No existing users
|
||||
|
||||
# Act
|
||||
EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
mock_context.__enter__.assert_called_once()
|
||||
mock_context.__exit__.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
@ -1,99 +0,0 @@
|
||||
"""
|
||||
Unit tests for `services.file_service.FileService` helpers.
|
||||
|
||||
We keep these tests focused on:
|
||||
- ZIP tempfile building (sanitization + deduplication + content writes)
|
||||
- tenant-scoped batch lookup behavior (`get_upload_files_by_ids`)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
|
||||
import services.file_service as file_service_module
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Ensure ZIP entry names are safe and unique while preserving extensions."""
|
||||
|
||||
# Arrange: three upload files that all sanitize down to the same basename ("b.txt").
|
||||
upload_files: list[Any] = [
|
||||
SimpleNamespace(name="a/b.txt", key="k1"),
|
||||
SimpleNamespace(name="c/b.txt", key="k2"),
|
||||
SimpleNamespace(name="../b.txt", key="k3"),
|
||||
]
|
||||
|
||||
# Stream distinct bytes per key so we can verify content is written to the right entry.
|
||||
data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]}
|
||||
|
||||
def _load(key: str, stream: bool = True) -> list[bytes]:
|
||||
# Return the corresponding chunks for this key (the production code iterates chunks).
|
||||
assert stream is True
|
||||
return data_by_key[key]
|
||||
|
||||
monkeypatch.setattr(file_service_module.storage, "load", _load)
|
||||
|
||||
# Act: build zip in a tempfile.
|
||||
with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp:
|
||||
with ZipFile(tmp, mode="r") as zf:
|
||||
# Assert: names are sanitized (no directory components) and deduped with suffixes.
|
||||
assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"]
|
||||
|
||||
# Assert: each entry contains the correct bytes from storage.
|
||||
assert zf.read("b.txt") == b"one"
|
||||
assert zf.read("b (1).txt") == b"two"
|
||||
assert zf.read("b (2).txt") == b"three"
|
||||
|
||||
|
||||
def test_get_upload_files_by_ids_returns_empty_when_no_ids(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Ensure empty input returns an empty mapping without hitting the database."""
|
||||
|
||||
class _Session:
|
||||
def scalars(self, _stmt): # type: ignore[no-untyped-def]
|
||||
raise AssertionError("db.session.scalars should not be called for empty id lists")
|
||||
|
||||
monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=_Session()))
|
||||
|
||||
assert FileService.get_upload_files_by_ids("tenant-1", []) == {}
|
||||
|
||||
|
||||
def test_get_upload_files_by_ids_returns_id_keyed_mapping(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Ensure batch lookup returns a dict keyed by stringified UploadFile ids."""
|
||||
|
||||
upload_files: list[Any] = [
|
||||
SimpleNamespace(id="file-1", tenant_id="tenant-1"),
|
||||
SimpleNamespace(id="file-2", tenant_id="tenant-1"),
|
||||
]
|
||||
|
||||
class _ScalarResult:
|
||||
def __init__(self, items: list[Any]) -> None:
|
||||
self._items = items
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return self._items
|
||||
|
||||
class _Session:
|
||||
def __init__(self, items: list[Any]) -> None:
|
||||
self._items = items
|
||||
self.calls: list[object] = []
|
||||
|
||||
def scalars(self, stmt): # type: ignore[no-untyped-def]
|
||||
# Capture the statement so we can at least assert the query path is taken.
|
||||
self.calls.append(stmt)
|
||||
return _ScalarResult(self._items)
|
||||
|
||||
session = _Session(upload_files)
|
||||
monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=session))
|
||||
|
||||
# Provide duplicates to ensure callers can safely pass repeated ids.
|
||||
result = FileService.get_upload_files_by_ids("tenant-1", ["file-1", "file-1", "file-2"])
|
||||
|
||||
assert set(result.keys()) == {"file-1", "file-2"}
|
||||
assert result["file-1"].id == "file-1"
|
||||
assert result["file-2"].id == "file-2"
|
||||
assert len(session.calls) == 1
|
||||
@ -1,224 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.oauth_server.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(mocker: MockerFixture) -> MagicMock:
|
||||
"""Mock the OAuth server Session context manager."""
|
||||
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
mocker.patch("services.oauth_server.Session", return_value=session_cm)
|
||||
return session
|
||||
|
||||
|
||||
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_execute_result = MagicMock()
|
||||
expected_app = MagicMock()
|
||||
mock_execute_result.scalar_one_or_none.return_value = expected_app
|
||||
mock_session.execute.return_value = mock_execute_result
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.get_oauth_provider_app("client-1")
|
||||
|
||||
# Assert
|
||||
assert result is expected_app
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_execute_result.scalar_one_or_none.assert_called_once()
|
||||
|
||||
|
||||
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
# Assert
|
||||
expected_code = str(deterministic_uuid)
|
||||
assert code == expected_code
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
|
||||
# Act
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
mock_redis_client.delete.assert_called_once_with(code_key)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
|
||||
# Act
|
||||
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh_token == "refresh-1"
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
|
||||
# Arrange
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=grant_type,
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
# Assert
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_load_user_when_token_exists(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
# Assert
|
||||
assert result is expected_user
|
||||
mock_load_user.assert_called_once_with("user-88")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user