mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 14:14:17 +08:00
chore: unify api && clean some type ignore (#35984)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e48d7bb097
commit
c67ce6f66d
@ -21,8 +21,6 @@ from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService, LangContentDict
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InsertExploreAppPayload(BaseModel):
|
||||
app_id: str = Field(...)
|
||||
@ -59,15 +57,7 @@ class InsertExploreBannerPayload(BaseModel):
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreAppPayload.__name__,
|
||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreBannerPayload.__name__,
|
||||
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_schema_models(console_ns, InsertExploreAppPayload, InsertExploreBannerPayload)
|
||||
|
||||
|
||||
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
@ -34,7 +34,7 @@ class AdvancedPromptTemplateList(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True))
|
||||
prompt_args: AdvancedPromptTemplateArgs = {
|
||||
"app_mode": args.app_mode,
|
||||
"model_mode": args.model_mode,
|
||||
|
||||
@ -2,6 +2,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -10,8 +11,6 @@ from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AgentLogQuery(BaseModel):
|
||||
message_id: str = Field(..., description="Message UUID")
|
||||
@ -23,9 +22,7 @@ class AgentLogQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, AgentLogQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
@ -44,6 +41,6 @@ class AgentLogApi(Resource):
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
|
||||
|
||||
@ -33,8 +33,6 @@ from services.annotation_service import (
|
||||
UpsertAnnotationArgs,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AnnotationReplyPayload(BaseModel):
|
||||
score_threshold: float = Field(..., description="Score threshold for annotation matching")
|
||||
@ -87,17 +85,6 @@ class AnnotationFilePayload(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
def reg(model: type[BaseModel]) -> None:
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AnnotationReplyPayload)
|
||||
reg(AnnotationSettingUpdatePayload)
|
||||
reg(AnnotationListQuery)
|
||||
reg(CreateAnnotationPayload)
|
||||
reg(UpdateAnnotationPayload)
|
||||
reg(AnnotationReplyStatusQuery)
|
||||
reg(AnnotationFilePayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
Annotation,
|
||||
@ -105,6 +92,13 @@ register_schema_models(
|
||||
AnnotationExportList,
|
||||
AnnotationHitHistory,
|
||||
AnnotationHitHistoryList,
|
||||
AnnotationReplyPayload,
|
||||
AnnotationSettingUpdatePayload,
|
||||
AnnotationListQuery,
|
||||
CreateAnnotationPayload,
|
||||
UpdateAnnotationPayload,
|
||||
AnnotationReplyStatusQuery,
|
||||
AnnotationFilePayload,
|
||||
)
|
||||
|
||||
|
||||
@ -218,7 +212,7 @@ class AnnotationApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
keyword = args.keyword
|
||||
|
||||
@ -701,7 +701,7 @@ class AppExportApi(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
payload = AppExportResponse(
|
||||
data=AppDslService.export_dsl(
|
||||
|
||||
@ -173,7 +173,7 @@ class TextModesApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
try:
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
|
||||
@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -37,7 +38,6 @@ from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
@ -65,13 +65,7 @@ class ChatMessagePayload(BaseMessagePayload):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionMessagePayload.__name__,
|
||||
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
|
||||
@ -39,8 +39,6 @@ from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseConversationQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
@ -70,15 +68,6 @@ class ChatConversationQuery(BaseConversationQuery):
|
||||
)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionConversationQuery.__name__,
|
||||
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatConversationQuery.__name__,
|
||||
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
CompletionConversationQuery,
|
||||
@ -89,6 +78,8 @@ register_schema_models(
|
||||
ConversationWithSummaryPaginationResponse,
|
||||
ConversationDetailResponse,
|
||||
ResultResponse,
|
||||
CompletionConversationQuery,
|
||||
ChatConversationQuery,
|
||||
)
|
||||
|
||||
|
||||
@ -107,7 +98,7 @@ class CompletionConversationApi(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
@ -221,7 +212,7 @@ class ChatConversationApi(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
subquery = (
|
||||
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||
|
||||
@ -100,7 +100,7 @@ class ConversationVariablesApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def get(self, app_model):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
|
||||
@ -5,14 +5,13 @@ from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.ops_service import OpsService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class TraceProviderQuery(BaseModel):
|
||||
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||
@ -23,13 +22,7 @@ class TraceConfigPayload(BaseModel):
|
||||
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TraceProviderQuery.__name__,
|
||||
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, TraceProviderQuery, TraceConfigPayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
||||
@ -50,7 +43,7 @@ class TraceAppConfigApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
@ -121,7 +114,7 @@ class TraceAppConfigApi(Resource):
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
"""Delete an existing trace app configuration"""
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
|
||||
@ -5,6 +5,7 @@ from flask import abort, jsonify, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -15,8 +16,6 @@ from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
@ -30,10 +29,7 @@ class StatisticTimeRangeQuery(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
StatisticTimeRangeQuery.__name__,
|
||||
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_schema_models(console_ns, StatisticTimeRangeQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
@ -54,7 +50,7 @@ class DailyMessageStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -111,7 +107,7 @@ class DailyConversationStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -167,7 +163,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -224,7 +220,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -284,7 +280,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -360,7 +356,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -426,7 +422,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -482,7 +478,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
|
||||
@ -60,7 +60,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
|
||||
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
|
||||
@ -912,7 +912,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
filters = None
|
||||
if args.q:
|
||||
@ -1005,7 +1005,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
|
||||
@ -185,7 +185,7 @@ class WorkflowAppLogApi(Resource):
|
||||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
@ -228,7 +228,7 @@ class WorkflowArchivedLogApi(Resource):
|
||||
"""
|
||||
Get workflow archived logs
|
||||
"""
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
|
||||
@ -23,7 +23,6 @@ from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
@ -52,13 +51,14 @@ class WorkflowCommentMentionUsersPayload(BaseModel):
|
||||
users: list[AccountWithRole]
|
||||
|
||||
|
||||
for model in (
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountWithRole,
|
||||
WorkflowCommentMentionUsersPayload,
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyPayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload)
|
||||
)
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
|
||||
@ -8,6 +8,7 @@ from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
@ -33,7 +34,6 @@ from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowDraftVariableListQuery(BaseModel):
|
||||
@ -56,21 +56,12 @@ class EnvironmentVariableUpdatePayload(BaseModel):
|
||||
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ConversationVariableUpdatePayload.__name__,
|
||||
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EnvironmentVariableUpdatePayload.__name__,
|
||||
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
WorkflowDraftVariableListQuery,
|
||||
WorkflowDraftVariableUpdatePayload,
|
||||
ConversationVariableUpdatePayload,
|
||||
EnvironmentVariableUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
@ -260,7 +251,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -154,7 +154,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
@ -250,7 +250,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
@ -290,7 +290,7 @@ class WorkflowRunListApi(Resource):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
@ -331,7 +331,7 @@ class WorkflowRunCountApi(Resource):
|
||||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True))
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
|
||||
@ -3,6 +3,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -13,8 +14,6 @@ from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowStatisticQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
|
||||
@ -28,10 +27,7 @@ class WorkflowStatisticQuery(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowStatisticQuery.__name__,
|
||||
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_schema_models(console_ns, WorkflowStatisticQuery)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
@ -53,7 +49,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -93,7 +89,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -133,7 +129,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -173,7 +169,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
|
||||
@ -94,7 +94,7 @@ class WebhookTriggerApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ class ActivateCheckApi(Resource):
|
||||
console_ns.models[ActivationCheckResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
workspaceId = args.workspace_id
|
||||
token = args.token
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
@ -8,8 +9,6 @@ from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
@ -17,10 +16,7 @@ class ApiKeyAuthBindingPayload(BaseModel):
|
||||
credentials: dict = Field(...)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ApiKeyAuthBindingPayload.__name__,
|
||||
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_schema_models(console_ns, ApiKeyAuthBindingPayload)
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
|
||||
@ -4,6 +4,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -23,8 +24,6 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EmailRegisterSendPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
@ -48,8 +47,7 @@ class EmailRegisterResetPayload(BaseModel):
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload)
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
|
||||
@ -28,8 +28,6 @@ from services.entities.auth_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordEmailResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
@ -9,6 +9,7 @@ from werkzeug.exceptions import Unauthorized
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import get_valid_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
@ -50,7 +51,6 @@ from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -71,13 +71,7 @@ class EmailCodeLoginPayload(BaseModel):
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(LoginPayload)
|
||||
reg(EmailPayload)
|
||||
reg(EmailCodeLoginPayload)
|
||||
register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload)
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
|
||||
@ -4,6 +4,7 @@ from flask_restx import ( # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -12,8 +13,6 @@ from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
inputs: dict
|
||||
@ -21,7 +20,7 @@ class Parser(BaseModel):
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
register_schema_models(console_ns, Parser)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
|
||||
@ -80,7 +80,7 @@ class RecommendedAppListApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
language = args.language
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
|
||||
@ -10,7 +10,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -120,10 +120,6 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
|
||||
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
|
||||
|
||||
|
||||
# Pydantic models for request validation
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunRequest(BaseModel):
|
||||
inputs: dict
|
||||
files: list | None = None
|
||||
@ -153,19 +149,7 @@ class CompletionRequest(BaseModel):
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
# Register schemas for Swagger documentation
|
||||
console_ns.schema_model(
|
||||
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, WorkflowRunRequest, ChatRequest, TextToSpeechRequest, CompletionRequest)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
|
||||
@ -89,7 +89,7 @@ class CodeBasedExtensionAPI(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
return CodeBasedExtensionResponse(
|
||||
module=query.module,
|
||||
|
||||
@ -52,8 +52,6 @@ from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AccountInitPayload(BaseModel):
|
||||
interface_language: str
|
||||
@ -161,27 +159,26 @@ class CheckEmailUniquePayload(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AccountInitPayload)
|
||||
reg(AccountNamePayload)
|
||||
reg(AccountAvatarPayload)
|
||||
reg(AccountAvatarQuery)
|
||||
reg(AccountInterfaceLanguagePayload)
|
||||
reg(AccountInterfaceThemePayload)
|
||||
reg(AccountTimezonePayload)
|
||||
reg(AccountPasswordPayload)
|
||||
reg(AccountDeletePayload)
|
||||
reg(AccountDeletionFeedbackPayload)
|
||||
reg(EducationActivatePayload)
|
||||
reg(EducationAutocompleteQuery)
|
||||
reg(ChangeEmailSendPayload)
|
||||
reg(ChangeEmailValidityPayload)
|
||||
reg(ChangeEmailResetPayload)
|
||||
reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AccountInitPayload,
|
||||
AccountNamePayload,
|
||||
AccountAvatarPayload,
|
||||
AccountAvatarQuery,
|
||||
AccountInterfaceLanguagePayload,
|
||||
AccountInterfaceThemePayload,
|
||||
AccountTimezonePayload,
|
||||
AccountPasswordPayload,
|
||||
AccountDeletePayload,
|
||||
AccountDeletionFeedbackPayload,
|
||||
EducationActivatePayload,
|
||||
EducationAutocompleteQuery,
|
||||
ChangeEmailSendPayload,
|
||||
ChangeEmailValidityPayload,
|
||||
ChangeEmailResetPayload,
|
||||
CheckEmailUniquePayload,
|
||||
)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict[str, Any]:
|
||||
@ -326,7 +323,7 @@ class AccountAvatarApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True))
|
||||
avatar = args.avatar
|
||||
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
|
||||
@ -20,8 +20,6 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EndpointCreatePayload(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
@ -80,10 +78,6 @@ class EndpointDisableResponse(BaseModel):
|
||||
success: bool = Field(description="Operation success")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
EndpointCreatePayload,
|
||||
@ -215,7 +209,7 @@ class EndpointListApi(Resource):
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
page_size = args.page_size
|
||||
@ -248,7 +242,7 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
page_size = args.page_size
|
||||
|
||||
@ -33,8 +33,6 @@ from services.account_service import AccountService, RegisterService, TenantServ
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class MemberInvitePayload(BaseModel):
|
||||
emails: list[str] = Field(default_factory=list)
|
||||
@ -59,17 +57,17 @@ class OwnerTransferPayload(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(MemberInvitePayload)
|
||||
reg(MemberRoleUpdatePayload)
|
||||
reg(OwnerTransferEmailPayload)
|
||||
reg(OwnerTransferCheckPayload)
|
||||
reg(OwnerTransferPayload)
|
||||
register_enum_models(console_ns, TenantAccountRole)
|
||||
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountWithRole,
|
||||
AccountWithRoleList,
|
||||
MemberInvitePayload,
|
||||
MemberRoleUpdatePayload,
|
||||
OwnerTransferEmailPayload,
|
||||
OwnerTransferCheckPayload,
|
||||
OwnerTransferPayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
|
||||
@ -5,6 +5,7 @@ from flask import request, send_file
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
@ -15,8 +16,6 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ParserModelList(BaseModel):
|
||||
model_type: ModelType | None = None
|
||||
@ -75,18 +74,17 @@ class ParserPreferredProviderType(BaseModel):
|
||||
preferred_provider_type: Literal["system", "custom"]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(ParserModelList)
|
||||
reg(ParserCredentialId)
|
||||
reg(ParserCredentialCreate)
|
||||
reg(ParserCredentialUpdate)
|
||||
reg(ParserCredentialDelete)
|
||||
reg(ParserCredentialSwitch)
|
||||
reg(ParserCredentialValidate)
|
||||
reg(ParserPreferredProviderType)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ParserModelList,
|
||||
ParserCredentialId,
|
||||
ParserCredentialCreate,
|
||||
ParserCredentialUpdate,
|
||||
ParserCredentialDelete,
|
||||
ParserCredentialSwitch,
|
||||
ParserCredentialValidate,
|
||||
ParserPreferredProviderType,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
|
||||
@ -17,7 +17,6 @@ from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ParserGetDefault(BaseModel):
|
||||
@ -107,6 +106,12 @@ class ParserParameter(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
class ParserSwitch(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credential_id: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ParserGetDefault,
|
||||
@ -119,6 +124,7 @@ register_schema_models(
|
||||
ParserDeleteCredential,
|
||||
ParserParameter,
|
||||
Inner,
|
||||
ParserSwitch,
|
||||
)
|
||||
|
||||
register_enum_models(console_ns, ModelType)
|
||||
@ -133,7 +139,7 @@ class DefaultModelApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
@ -261,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
current_credential = model_provider_service.get_model_credential(
|
||||
@ -387,17 +393,6 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ParserSwitch(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credential_id: str
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserSwitch.__name__])
|
||||
@ -468,9 +463,7 @@ class ParserValidate(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(console_ns, ParserSwitch, ParserValidate)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
|
||||
@ -515,7 +508,7 @@ class ModelProviderModelParameterRuleApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserParameter.model_validate(request.args.to_dict(flat=True))
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@ -211,7 +211,7 @@ class PluginListApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserList.model_validate(request.args.to_dict(flat=True))
|
||||
try:
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -261,7 +261,7 @@ class PluginIconApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserIcon.__name__])
|
||||
@setup_required
|
||||
def get(self):
|
||||
args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserIcon.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
|
||||
@ -279,7 +279,7 @@ class PluginAssetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserAsset.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
try:
|
||||
@ -421,7 +421,7 @@ class PluginFetchMarketplacePkgApi(Resource):
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@ -446,7 +446,7 @@ class PluginFetchManifestApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
@ -466,7 +466,7 @@ class PluginFetchInstallTasksApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserTasks.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
|
||||
@ -660,7 +660,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
user_id = current_user.id
|
||||
|
||||
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
@ -822,7 +822,7 @@ class PluginReadmeApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ParserReadme.model_validate(request.args.to_dict(flat=True))
|
||||
return jsonable_encoder(
|
||||
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
|
||||
)
|
||||
|
||||
@ -16,6 +16,7 @@ from controllers.common.errors import (
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
@ -39,7 +40,6 @@ from services.file_service import FileService
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkspaceListQuery(BaseModel):
|
||||
@ -91,15 +91,14 @@ class TenantInfoResponse(ResponseModel):
|
||||
return value
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(WorkspaceListQuery)
|
||||
reg(SwitchWorkspacePayload)
|
||||
reg(WorkspaceCustomConfigPayload)
|
||||
reg(WorkspaceInfoPayload)
|
||||
reg(TenantInfoResponse)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
WorkspaceListQuery,
|
||||
SwitchWorkspacePayload,
|
||||
WorkspaceCustomConfigPayload,
|
||||
WorkspaceInfoPayload,
|
||||
TenantInfoResponse,
|
||||
)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
|
||||
@ -8,13 +8,12 @@ from werkzeug.exceptions import NotFound
|
||||
import services
|
||||
from controllers.common.errors import UnsupportedFileTypeError
|
||||
from controllers.common.file_response import enforce_download_for_html
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.files import files_ns
|
||||
from extensions.ext_database import db
|
||||
from services.account_service import TenantService
|
||||
from services.file_service import FileService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class FileSignatureQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp used in the signature")
|
||||
@ -26,12 +25,7 @@ class FilePreviewQuery(FileSignatureQuery):
|
||||
as_attachment: bool = Field(default=False, description="Whether to download as attachment")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
files_ns.schema_model(
|
||||
FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(files_ns, FileSignatureQuery, FilePreviewQuery)
|
||||
|
||||
|
||||
@files_ns.route("/<uuid:file_id>/image-preview")
|
||||
@ -58,7 +52,7 @@ class ImagePreviewApi(Resource):
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True))
|
||||
timestamp = args.timestamp
|
||||
nonce = args.nonce
|
||||
sign = args.sign
|
||||
@ -100,7 +94,7 @@ class FilePreviewApi(Resource):
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
||||
|
||||
@ -7,12 +7,11 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.common.errors import UnsupportedFileTypeError
|
||||
from controllers.common.file_response import enforce_download_for_html
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ToolFileQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp")
|
||||
@ -21,9 +20,7 @@ class ToolFileQuery(BaseModel):
|
||||
as_attachment: bool = Field(default=False, description="Download as attachment")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(files_ns, ToolFileQuery)
|
||||
|
||||
|
||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||
|
||||
@ -20,8 +20,6 @@ from ..console.wraps import setup_required
|
||||
from ..files import files_ns
|
||||
from ..inner_api.plugin.wraps import get_user
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class PluginUploadQuery(BaseModel):
|
||||
timestamp: str = Field(..., description="Unix timestamp for signature verification")
|
||||
@ -31,9 +29,8 @@ class PluginUploadQuery(BaseModel):
|
||||
user_id: str | None = Field(default=None, description="User identifier")
|
||||
|
||||
|
||||
files_ns.schema_model(
|
||||
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
register_schema_models(files_ns, PluginUploadQuery)
|
||||
|
||||
|
||||
register_schema_models(files_ns, FileResponse)
|
||||
|
||||
@ -69,7 +66,7 @@ class PluginUploadFileApi(Resource):
|
||||
FileTooLargeError: File exceeds size limit
|
||||
UnsupportedFileTypeError: File type not supported
|
||||
"""
|
||||
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
file = request.files.get("file")
|
||||
if file is None:
|
||||
|
||||
@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_valid
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
@ -34,13 +34,7 @@ from services.tag_service import (
|
||||
UpdateTagPayload,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
service_api_ns.schema_model(
|
||||
DatasetPermissionEnum.__name__,
|
||||
TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_enum_models(service_api_ns, DatasetPermissionEnum)
|
||||
|
||||
|
||||
class DatasetCreatePayload(BaseModel):
|
||||
|
||||
@ -77,9 +77,6 @@ class DocumentTextCreatePayload(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class DocumentTextUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
text: str | None = None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user