mirror of https://github.com/langgenius/dify.git
Fix typing errors in dataset API (#26424)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1a7898dff1
commit
d77c2e4d17
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
import services
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from controllers.service_api.wraps import (
|
||||
|
|
@ -254,19 +254,21 @@ class DatasetListApi(DatasetApiResource):
|
|||
"""Resource for creating datasets."""
|
||||
args = dataset_create_parser.parse_args()
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -317,7 +319,7 @@ class DatasetApi(DatasetApiResource):
|
|||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
assert isinstance(current_user, Account)
|
||||
|
|
@ -331,8 +333,8 @@ class DatasetApi(DatasetApiResource):
|
|||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data["indexing_technique"] == "high_quality":
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
else:
|
||||
|
|
@ -341,7 +343,9 @@ class DatasetApi(DatasetApiResource):
|
|||
data["embedding_available"] = True
|
||||
|
||||
# force update search method to keyword_search if indexing_technique is economic
|
||||
data["retrieval_model_dict"]["search_method"] = "keyword_search"
|
||||
retrieval_model_dict = data.get("retrieval_model_dict")
|
||||
if retrieval_model_dict:
|
||||
retrieval_model_dict["search_method"] = "keyword_search"
|
||||
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
|
@ -372,19 +376,24 @@ class DatasetApi(DatasetApiResource):
|
|||
data = request.get_json()
|
||||
|
||||
# check embedding model setting
|
||||
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
|
||||
embedding_model_provider = data.get("embedding_model_provider")
|
||||
embedding_model = data.get("embedding_model")
|
||||
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||
)
|
||||
|
||||
retrieval_model = data.get("retrieval_model")
|
||||
if (
|
||||
data.get("retrieval_model")
|
||||
and data.get("retrieval_model").get("reranking_model")
|
||||
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
|
|
@ -397,7 +406,7 @@ class DatasetApi(DatasetApiResource):
|
|||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -591,9 +600,10 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
|
||||
args = tag_update_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
tag_id = args["tag_id"]
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
|
|
@ -616,7 +626,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
TagService.delete_tag(args["tag_id"])
|
||||
|
||||
return 204
|
||||
|
||||
|
|
|
|||
|
|
@ -108,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
|
|
@ -187,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
|||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
"extensions",
|
||||
"libs",
|
||||
"controllers/console/datasets",
|
||||
"controllers/service_api/dataset",
|
||||
"core/ops",
|
||||
"core/tools",
|
||||
"core/model_runtime",
|
||||
|
|
|
|||
Loading…
Reference in New Issue