mirror of https://github.com/langgenius/dify.git
refactor: replace request.args.get with Pydantic BaseModel validation (#31104)
Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
360f3bb32f
commit
f6be9cd90d
|
|
@ -36,6 +36,16 @@ class NotionEstimatePayload(BaseModel):
|
|||
doc_language: str = Field(default="English")
|
||||
|
||||
|
||||
class DataSourceNotionListQuery(BaseModel):
|
||||
dataset_id: str | None = Field(default=None, description="Dataset ID")
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
|
||||
|
||||
|
||||
class DataSourceNotionPreviewQuery(BaseModel):
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
|
||||
|
||||
|
|
@ -136,26 +146,15 @@ class DataSourceNotionListApi(Resource):
|
|||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
|
||||
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
|
||||
datasource_parameters = {}
|
||||
if datasource_parameters_str:
|
||||
try:
|
||||
datasource_parameters = json.loads(datasource_parameters_str)
|
||||
if not isinstance(datasource_parameters, dict):
|
||||
raise ValueError("datasource_parameters must be a JSON object.")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid datasource_parameters JSON format.")
|
||||
datasource_parameters = query.datasource_parameters or {}
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
credential_id=query.credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
|
@ -164,8 +163,8 @@ class DataSourceNotionListApi(Resource):
|
|||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
if dataset_id:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if query.dataset_id:
|
||||
dataset = DatasetService.get_dataset(query.dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if dataset.data_source_type != "notion_import":
|
||||
|
|
@ -173,7 +172,7 @@ class DataSourceNotionListApi(Resource):
|
|||
|
||||
documents = session.scalars(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
dataset_id=query.dataset_id,
|
||||
tenant_id=current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
|
|
@ -240,13 +239,12 @@ class DataSourceNotionApi(Resource):
|
|||
def get(self, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
credential_id=query.credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -176,7 +176,18 @@ class IndexingEstimatePayload(BaseModel):
|
|||
return result
|
||||
|
||||
|
||||
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
|
||||
class ConsoleDatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
include_all: bool = Field(default=False, description="Include all datasets")
|
||||
ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs")
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
|
||||
)
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
|
|
@ -275,18 +286,19 @@ class DatasetListApi(Resource):
|
|||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
query = ConsoleDatasetListQuery.model_validate(request.args.to_dict(flat=False))
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
include_all = request.args.get("include_all", default="false").lower() == "true"
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
|
||||
if query.ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
|
||||
else:
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
|
||||
query.page,
|
||||
query.limit,
|
||||
current_tenant_id,
|
||||
current_user,
|
||||
query.keyword,
|
||||
query.tag_ids,
|
||||
query.include_all,
|
||||
)
|
||||
|
||||
# check embedding setting
|
||||
|
|
@ -318,7 +330,13 @@ class DatasetListApi(Resource):
|
|||
else:
|
||||
item.update({"partial_member_list": []})
|
||||
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@console_ns.doc("create_dataset")
|
||||
|
|
|
|||
|
|
@ -98,12 +98,19 @@ class BedrockRetrievalPayload(BaseModel):
|
|||
knowledge_id: str
|
||||
|
||||
|
||||
class ExternalApiTemplateListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ExternalKnowledgeApiPayload,
|
||||
ExternalDatasetCreatePayload,
|
||||
ExternalHitTestingPayload,
|
||||
BedrockRetrievalPayload,
|
||||
ExternalApiTemplateListQuery,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -124,19 +131,17 @@ class ExternalApiTemplateListApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||
page, limit, current_tenant_id, search
|
||||
query.page, query.limit, current_tenant_id, query.keyword
|
||||
)
|
||||
response = {
|
||||
"data": [item.to_dict() for item in external_knowledge_apis],
|
||||
"has_more": len(external_knowledge_apis) == limit,
|
||||
"limit": limit,
|
||||
"has_more": len(external_knowledge_apis) == query.limit,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page": query.page,
|
||||
}
|
||||
return response, 200
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Any
|
|||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
|
|
@ -28,6 +28,10 @@ class InstalledAppUpdatePayload(BaseModel):
|
|||
is_pinned: bool | None = None
|
||||
|
||||
|
||||
class InstalledAppsListQuery(BaseModel):
|
||||
app_id: str | None = Field(default=None, description="App ID to filter by")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -37,13 +41,13 @@ class InstalledAppsListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(installed_app_list_fields)
|
||||
def get(self):
|
||||
app_id = request.args.get("app_id", default=None, type=str)
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if app_id:
|
||||
if query.app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
select(InstalledApp).where(
|
||||
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
|
||||
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ register_schema_models(
|
|||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -87,6 +87,14 @@ class TagUnbindingPayload(BaseModel):
|
|||
target_id: str
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
include_all: bool = Field(default=False, description="Include all datasets")
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
DatasetCreatePayload,
|
||||
|
|
@ -96,6 +104,7 @@ register_schema_models(
|
|||
TagDeletePayload,
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -113,15 +122,11 @@ class DatasetListApi(DatasetApiResource):
|
|||
)
|
||||
def get(self, tenant_id):
|
||||
"""Resource for getting datasets."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
query = DatasetListQuery.model_validate(request.args.to_dict(flat=False))
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
include_all = request.args.get("include_all", default="false").lower() == "true"
|
||||
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, tenant_id, current_user, search, tag_ids, include_all
|
||||
query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
|
||||
)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
|
|
@ -147,7 +152,13 @@ class DatasetListApi(DatasetApiResource):
|
|||
item["embedding_available"] = False
|
||||
else:
|
||||
item["embedding_available"] = True
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
|
|
|
|||
|
|
@ -69,7 +69,14 @@ class DocumentTextUpdate(BaseModel):
|
|||
return self
|
||||
|
||||
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
|
||||
class DocumentListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
status: str | None = Field(default=None, description="Document status filter")
|
||||
|
||||
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery]:
|
||||
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
|
||||
|
||||
|
||||
|
|
@ -460,34 +467,33 @@ class DocumentListApi(DatasetApiResource):
|
|||
def get(self, tenant_id, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
status = request.args.get("status", default=None, type=str)
|
||||
query_params = DocumentListQuery.model_validate(request.args.to_dict())
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
|
||||
|
||||
if status:
|
||||
query = DocumentService.apply_display_status_filter(query, status)
|
||||
if query_params.status:
|
||||
query = DocumentService.apply_display_status_filter(query, query_params.status)
|
||||
|
||||
if search:
|
||||
search = f"%{search}%"
|
||||
if query_params.keyword:
|
||||
search = f"%{query_params.keyword}%"
|
||||
query = query.where(Document.name.like(search))
|
||||
|
||||
query = query.order_by(desc(Document.created_at), desc(Document.position))
|
||||
|
||||
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
paginated_documents = db.paginate(
|
||||
select=query, page=query_params.page, per_page=query_params.limit, max_per_page=100, error_out=False
|
||||
)
|
||||
documents = paginated_documents.items
|
||||
|
||||
response = {
|
||||
"data": marshal(documents, document_fields),
|
||||
"has_more": len(documents) == limit,
|
||||
"limit": limit,
|
||||
"has_more": len(documents) == query_params.limit,
|
||||
"limit": query_params.limit,
|
||||
"total": paginated_documents.total,
|
||||
"page": page,
|
||||
"page": query_params.page,
|
||||
}
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
|
|
@ -34,7 +34,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
|||
def get_online_document_pages(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
datasource_parameters: dict[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
|
|
|||
Loading…
Reference in New Issue