mirror of https://github.com/langgenius/dify.git
Fix json in md when use quesion classifier node (#26992)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
5937a66e22
commit
830f891a74
|
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
|
@ -194,6 +195,8 @@ class QuestionClassifierNode(Node):
|
||||||
|
|
||||||
category_name = node_data.classes[0].name
|
category_name = node_data.classes[0].name
|
||||||
category_id = node_data.classes[0].id
|
category_id = node_data.classes[0].id
|
||||||
|
if "<think>" in result_text:
|
||||||
|
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
|
||||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||||
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||||
if "category_name" in result_text_json and "category_id" in result_text_json:
|
if "category_name" in result_text_json and "category_id" in result_text_json:
|
||||||
|
|
|
||||||
|
|
@ -6,22 +6,22 @@ from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
def parse_json_markdown(json_string: str):
|
def parse_json_markdown(json_string: str):
|
||||||
# Get json from the backticks/braces
|
# Get json from the backticks/braces
|
||||||
json_string = json_string.strip()
|
json_string = json_string.strip()
|
||||||
starts = ["```json", "```", "``", "`", "{"]
|
starts = ["```json", "```", "``", "`", "{", "["]
|
||||||
ends = ["```", "``", "`", "}"]
|
ends = ["```", "``", "`", "}", "]"]
|
||||||
end_index = -1
|
end_index = -1
|
||||||
start_index = 0
|
start_index = 0
|
||||||
parsed: dict = {}
|
parsed: dict = {}
|
||||||
for s in starts:
|
for s in starts:
|
||||||
start_index = json_string.find(s)
|
start_index = json_string.find(s)
|
||||||
if start_index != -1:
|
if start_index != -1:
|
||||||
if json_string[start_index] != "{":
|
if json_string[start_index] not in ("{", "["):
|
||||||
start_index += len(s)
|
start_index += len(s)
|
||||||
break
|
break
|
||||||
if start_index != -1:
|
if start_index != -1:
|
||||||
for e in ends:
|
for e in ends:
|
||||||
end_index = json_string.rfind(e, start_index)
|
end_index = json_string.rfind(e, start_index)
|
||||||
if end_index != -1:
|
if end_index != -1:
|
||||||
if json_string[end_index] == "}":
|
if json_string[end_index] in ("}", "]"):
|
||||||
end_index += 1
|
end_index += 1
|
||||||
break
|
break
|
||||||
if start_index != -1 and end_index != -1 and start_index < end_index:
|
if start_index != -1 and end_index != -1 and start_index < end_index:
|
||||||
|
|
@ -38,6 +38,12 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]):
|
||||||
json_obj = parse_json_markdown(text)
|
json_obj = parse_json_markdown(text)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise OutputParserError(f"got invalid json object. error: {e}")
|
raise OutputParserError(f"got invalid json object. error: {e}")
|
||||||
|
|
||||||
|
if isinstance(json_obj, list):
|
||||||
|
if len(json_obj) == 1 and isinstance(json_obj[0], dict):
|
||||||
|
json_obj = json_obj[0]
|
||||||
|
else:
|
||||||
|
raise OutputParserError(f"got invalid return object. obj:{json_obj}")
|
||||||
for key in expected_keys:
|
for key in expected_keys:
|
||||||
if key not in json_obj:
|
if key not in json_obj:
|
||||||
raise OutputParserError(
|
raise OutputParserError(
|
||||||
|
|
|
||||||
|
|
@ -86,3 +86,24 @@ def test_parse_and_check_json_markdown_multiple_blocks_fails():
|
||||||
# opening fence to the last closing fence, causing JSON decode failure.
|
# opening fence to the last closing fence, causing JSON decode failure.
|
||||||
with pytest.raises(OutputParserError):
|
with pytest.raises(OutputParserError):
|
||||||
parse_and_check_json_markdown(src, [])
|
parse_and_check_json_markdown(src, [])
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_and_check_json_markdown_handles_think_fenced_and_raw_variants():
|
||||||
|
expected = {"keywords": ["2"], "category_id": "2", "category_name": "2"}
|
||||||
|
cases = [
|
||||||
|
"""
|
||||||
|
```json
|
||||||
|
[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]
|
||||||
|
```, error: Expecting value: line 1 column 1 (char 0)
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
```json
|
||||||
|
{"keywords": ["2"], "category_id": "2", "category_name": "2"}
|
||||||
|
```, error: Extra data: line 2 column 5 (char 66)
|
||||||
|
""",
|
||||||
|
'{"keywords": ["2"], "category_id": "2", "category_name": "2"}',
|
||||||
|
'[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]',
|
||||||
|
]
|
||||||
|
for src in cases:
|
||||||
|
obj = parse_and_check_json_markdown(src, ["keywords", "category_id", "category_name"])
|
||||||
|
assert obj == expected
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue