Mirror of https://github.com/langgenius/dify.git
refactor: port some if else to match (#31841)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
parent 1082f488a1
commit d625ac0bf1
@@ -606,63 +606,63 @@ class DatasetIndexingEstimateApi(Resource):
         # validate args
         DocumentService.estimate_args_validate(args)
         extract_settings = []
-        if args["info_list"]["data_source_type"] == "upload_file":
-            file_ids = args["info_list"]["file_info_list"]["file_ids"]
-            file_details = db.session.scalars(
-                select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
-            ).all()
-
-            if file_details is None:
-                raise NotFound("File not found.")
-
-            if file_details:
-                for file_detail in file_details:
-                    extract_setting = ExtractSetting(
-                        datasource_type=DatasourceType.FILE,
-                        upload_file=file_detail,
-                        document_model=args["doc_form"],
-                    )
-                    extract_settings.append(extract_setting)
-        elif args["info_list"]["data_source_type"] == "notion_import":
-            notion_info_list = args["info_list"]["notion_info_list"]
-            for notion_info in notion_info_list:
-                workspace_id = notion_info["workspace_id"]
-                credential_id = notion_info.get("credential_id")
-                for page in notion_info["pages"]:
-                    extract_setting = ExtractSetting(
-                        datasource_type=DatasourceType.NOTION,
-                        notion_info=NotionInfo.model_validate(
-                            {
-                                "credential_id": credential_id,
-                                "notion_workspace_id": workspace_id,
-                                "notion_obj_id": page["page_id"],
-                                "notion_page_type": page["type"],
-                                "tenant_id": current_tenant_id,
-                            }
-                        ),
-                        document_model=args["doc_form"],
-                    )
-                    extract_settings.append(extract_setting)
-        elif args["info_list"]["data_source_type"] == "website_crawl":
-            website_info_list = args["info_list"]["website_info_list"]
-            for url in website_info_list["urls"]:
-                extract_setting = ExtractSetting(
-                    datasource_type=DatasourceType.WEBSITE,
-                    website_info=WebsiteInfo.model_validate(
-                        {
-                            "provider": website_info_list["provider"],
-                            "job_id": website_info_list["job_id"],
-                            "url": url,
-                            "tenant_id": current_tenant_id,
-                            "mode": "crawl",
-                            "only_main_content": website_info_list["only_main_content"],
-                        }
-                    ),
-                    document_model=args["doc_form"],
-                )
-                extract_settings.append(extract_setting)
-        else:
-            raise ValueError("Data source type not support")
+        match args["info_list"]["data_source_type"]:
+            case "upload_file":
+                file_ids = args["info_list"]["file_info_list"]["file_ids"]
+                file_details = db.session.scalars(
+                    select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
+                ).all()
+                if file_details is None:
+                    raise NotFound("File not found.")
+
+                if file_details:
+                    for file_detail in file_details:
+                        extract_setting = ExtractSetting(
+                            datasource_type=DatasourceType.FILE,
+                            upload_file=file_detail,
+                            document_model=args["doc_form"],
+                        )
+                        extract_settings.append(extract_setting)
+            case "notion_import":
+                notion_info_list = args["info_list"]["notion_info_list"]
+                for notion_info in notion_info_list:
+                    workspace_id = notion_info["workspace_id"]
+                    credential_id = notion_info.get("credential_id")
+                    for page in notion_info["pages"]:
+                        extract_setting = ExtractSetting(
+                            datasource_type=DatasourceType.NOTION,
+                            notion_info=NotionInfo.model_validate(
+                                {
+                                    "credential_id": credential_id,
+                                    "notion_workspace_id": workspace_id,
+                                    "notion_obj_id": page["page_id"],
+                                    "notion_page_type": page["type"],
+                                    "tenant_id": current_tenant_id,
+                                }
+                            ),
+                            document_model=args["doc_form"],
+                        )
+                        extract_settings.append(extract_setting)
+            case "website_crawl":
+                website_info_list = args["info_list"]["website_info_list"]
+                for url in website_info_list["urls"]:
+                    extract_setting = ExtractSetting(
+                        datasource_type=DatasourceType.WEBSITE,
+                        website_info=WebsiteInfo.model_validate(
+                            {
+                                "provider": website_info_list["provider"],
+                                "job_id": website_info_list["job_id"],
+                                "url": url,
+                                "tenant_id": current_tenant_id,
+                                "mode": "crawl",
+                                "only_main_content": website_info_list["only_main_content"],
+                            }
+                        ),
+                        document_model=args["doc_form"],
+                    )
+                    extract_settings.append(extract_setting)
+            case _:
+                raise ValueError("Data source type not support")
 
         indexing_runner = IndexingRunner()
         try:
             response = indexing_runner.indexing_estimate(
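The shape of this hunk is the mechanical core of the whole PR: an if/elif chain that repeatedly compares the same expression (args["info_list"]["data_source_type"]) against string literals becomes a match statement, and the trailing else becomes case _. A minimal, self-contained sketch of the same shape (the names below are illustrative stand-ins, not Dify's actual API):

from dataclasses import dataclass


@dataclass
class ExtractSetting:
    # simplified stand-in for Dify's ExtractSetting model
    datasource_type: str


def build_extract_settings(info_list: dict) -> list[ExtractSetting]:
    # the subject expression is evaluated once, then tried against each case
    match info_list["data_source_type"]:
        case "upload_file":
            return [ExtractSetting(datasource_type="file")]
        case "notion_import":
            return [ExtractSetting(datasource_type="notion")]
        case "website_crawl":
            return [ExtractSetting(datasource_type="website")]
        case _:
            # direct translation of the old trailing `else`
            raise ValueError("Data source type not support")


print(build_extract_settings({"data_source_type": "upload_file"}))

Behavior is unchanged: string literals in a case are literal patterns compared with ==, and match never falls through from one case to the next.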
@@ -369,28 +369,31 @@ class DatasetDocumentListApi(Resource):
         else:
             sort_logic = asc
 
-        if sort == "hit_count":
-            sub_query = (
-                sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
-                .where(DocumentSegment.dataset_id == str(dataset_id))
-                .group_by(DocumentSegment.document_id)
-                .subquery()
-            )
-
-            query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
-                sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
-                sort_logic(Document.position),
-            )
-        elif sort == "created_at":
-            query = query.order_by(
-                sort_logic(Document.created_at),
-                sort_logic(Document.position),
-            )
-        else:
-            query = query.order_by(
-                desc(Document.created_at),
-                desc(Document.position),
-            )
+        match sort:
+            case "hit_count":
+                sub_query = (
+                    sa.select(
+                        DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")
+                    )
+                    .where(DocumentSegment.dataset_id == str(dataset_id))
+                    .group_by(DocumentSegment.document_id)
+                    .subquery()
+                )
+
+                query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
+                    sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
+                    sort_logic(Document.position),
+                )
+            case "created_at":
+                query = query.order_by(
+                    sort_logic(Document.created_at),
+                    sort_logic(Document.position),
+                )
+            case _:
+                query = query.order_by(
+                    desc(Document.created_at),
+                    desc(Document.position),
+                )
 
         paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
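A pitfall worth flagging for ports like this one: a bare name in a case is a capture pattern that binds and matches anything, not an equality test, so named constants cannot replace the string literals directly. A short sketch (illustrative names, not the handler above) of the trap and the working forms:

def describe_sort(sort: str) -> str:
    match sort:
        # case HIT_COUNT:  # would be a capture pattern: it matches *every* value
        case "hit_count":  # literal pattern, compared with ==
            return "order by summed hit_count, then position"
        case "created_at":
            return "order by created_at, then position"
        case _:  # replaces the old `else`: default ordering
            return "order by created_at desc, position desc"


assert describe_sort("created_at") == "order by created_at, then position"
assert describe_sort("unknown") == "order by created_at desc, position desc"

If a constant is needed, it has to be a dotted name (for example an enum member, as in the OpenSearch hunk below), which match treats as a value pattern.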
@@ -123,12 +123,15 @@ class SimplePromptTransform(PromptTransform):
 
         for v in special_variable_keys:
             # support #context#, #query# and #histories#
-            if v == "#context#":
-                variables["#context#"] = context or ""
-            elif v == "#query#":
-                variables["#query#"] = query or ""
-            elif v == "#histories#":
-                variables["#histories#"] = histories or ""
+            match v:
+                case "#context#":
+                    variables["#context#"] = context or ""
+                case "#query#":
+                    variables["#query#"] = query or ""
+                case "#histories#":
+                    variables["#histories#"] = histories or ""
+                case _:
+                    pass
 
         prompt_template = prompt_template_config["prompt_template"]
         if not isinstance(prompt_template, PromptTemplateParser):
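This hunk adds an explicit case _: pass that the original chain did not spell out. It is purely documentary: a match statement with no matching case simply does nothing, so the wildcard arm could be omitted without changing behavior. A runnable sketch of the equivalence (illustrative values):

variables: dict[str, str] = {}
context, query, histories = "ctx", "q", ""

for v in ["#context#", "#query#", "#histories#", "#unknown#"]:
    match v:
        case "#context#":
            variables["#context#"] = context or ""
        case "#query#":
            variables["#query#"] = query or ""
        case "#histories#":
            variables["#histories#"] = histories or ""
        case _:
            pass  # same as having no wildcard arm at all

# "#unknown#" was skipped; falsy values fall back to ""
assert variables == {"#context#": "ctx", "#query#": "q", "#histories#": ""}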
@@ -65,35 +65,18 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
         }
         file_list = values.get("file_list", [])
         if isinstance(v, str):
-            if field_name == "inputs":
-                return {
-                    "messages": {
-                        "role": "user",
-                        "content": v,
-                        "usage_metadata": usage_metadata,
-                        "file_list": file_list,
-                    },
-                }
-            elif field_name == "outputs":
-                return {
-                    "choices": {
-                        "role": "ai",
-                        "content": v,
-                        "usage_metadata": usage_metadata,
-                        "file_list": file_list,
-                    },
-                }
+            match field_name:
+                case "inputs":
+                    return {
+                        "messages": {
+                            "role": "user",
+                            "content": v,
+                            "usage_metadata": usage_metadata,
+                            "file_list": file_list,
+                        },
+                    }
+                case "outputs":
+                    return {
+                        "choices": {
+                            "role": "ai",
+                            "content": v,
+                            "usage_metadata": usage_metadata,
+                            "file_list": file_list,
+                        },
+                    }
+                case _:
+                    pass
@@ -101,6 +84,29 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
         elif isinstance(v, list):
             data = {}
             if len(v) > 0 and isinstance(v[0], dict):
                 # rename text to content
                 v = replace_text_with_content(data=v)
-                if field_name == "inputs":
-                    data = {
-                        "messages": v,
-                    }
-                elif field_name == "outputs":
-                    data = {
-                        "choices": {
-                            "role": "ai",
-                            "content": v,
-                            "usage_metadata": usage_metadata,
-                            "file_list": file_list,
-                        },
-                    }
+                match field_name:
+                    case "inputs":
+                        data = {
+                            "messages": v,
+                        }
+                    case "outputs":
+                        data = {
+                            "choices": {
+                                "role": "ai",
+                                "content": v,
+                                "usage_metadata": usage_metadata,
+                                "file_list": file_list,
+                            },
+                        }
+                    case _:
+                        pass
             return data
         else:
             return {
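The commit converts only the inner dispatch on field_name and leaves the outer isinstance(v, str) / isinstance(v, list) chain alone. If one wanted to go further, match can express the type dispatch too, via class patterns; a hedged sketch of that alternative shape (this is not what the commit does, and the payload is simplified):

def wrap(field_name: str, v: str | list) -> dict:
    key = "messages" if field_name == "inputs" else "choices"
    role = "user" if field_name == "inputs" else "ai"
    match v:
        case str():   # class pattern: matches any str, like isinstance(v, str)
            return {key: {"role": role, "content": v}}
        case list():  # class pattern: matches any list
            return {key: v}
        case _:
            return {"data": v}


assert wrap("inputs", "hi") == {"messages": {"role": "user", "content": "hi"}}
assert wrap("outputs", [{"role": "ai"}]) == {"choices": [{"role": "ai"}]}

Whether that reads better is debatable; the PR's narrower change keeps the diff easy to review.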
@@ -81,14 +81,15 @@ class OpenSearchConfig(BaseModel):
             pool_maxsize=20,
         )
 
-        if self.auth_method == "basic":
-            logger.info("Using basic authentication for OpenSearch Vector DB")
-
-            params["http_auth"] = (self.user, self.password)
-        elif self.auth_method == "aws_managed_iam":
-            logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
-
-            params["http_auth"] = self.create_aws_managed_iam_auth()
+        match self.auth_method:
+            case AuthMethod.BASIC:
+                logger.info("Using basic authentication for OpenSearch Vector DB")
+
+                params["http_auth"] = (self.user, self.password)
+            case AuthMethod.AWS_MANAGED_IAM:
+                logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
+
+                params["http_auth"] = self.create_aws_managed_iam_auth()
 
         return params
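This hunk is the one place the port is more than cosmetic: the old code compared self.auth_method against the raw strings "basic" and "aws_managed_iam", while the new code matches against AuthMethod members. In a case, a dotted name like AuthMethod.BASIC is a value pattern compared with ==, so this stays equivalent as long as AuthMethod is a str-based enum whose members compare equal to their string values, which the old string comparisons suggest. A sketch of that assumption (the enum below is a stand-in, not Dify's definition):

from enum import Enum


class AuthMethod(str, Enum):
    # str-based enum: AuthMethod.BASIC == "basic" is True
    BASIC = "basic"
    AWS_MANAGED_IAM = "aws_managed_iam"


def describe_auth(auth_method: AuthMethod | str) -> str:
    match auth_method:
        case AuthMethod.BASIC:  # value pattern: compared with ==
            return "basic auth (user, password)"
        case AuthMethod.AWS_MANAGED_IAM:
            return "AWS managed IAM signer"
        case _:
            return "no auth configured"


# the enum member and the raw string take the same branch
assert describe_auth("basic") == describe_auth(AuthMethod.BASIC)

If AuthMethod were a plain Enum without the str base, matching a raw string value would silently fall to case _, so the str base (or normalizing the field to the enum type) is load-bearing here.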