mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
Fix: Persian conversation titles robust detection, retry & translation fallback; precompile regex; move langdetect import; robust JSON parsing; lower LLM temperature; add tests; resolve Copilot comments (#29745)
This commit is contained in:
parent
5dc6fed97c
commit
a177097228
@ -43,6 +43,64 @@ _PERSIAN_HEURISTIC = re.compile(
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Precompiled regex for Persian-specific characters (including Persian ye U+06CC)
|
||||
_PERSIAN_CHARS_RE = re.compile(r"[پچژگک\u06CC]")
|
||||
|
||||
# Optional langdetect import — import once at module import time to avoid repeated lookups
|
||||
_LANGDETECT_AVAILABLE = False
|
||||
try:
|
||||
from langdetect import DetectorFactory, detect # type: ignore
|
||||
|
||||
DetectorFactory.seed = 0
|
||||
_LANGDETECT_AVAILABLE = True
|
||||
except Exception:
|
||||
detect = None
|
||||
DetectorFactory = None
|
||||
_LANGDETECT_AVAILABLE = False
|
||||
|
||||
|
||||
def _contains_persian(text: str) -> bool:
|
||||
"""Return True if text appears to be Persian (Farsi).
|
||||
|
||||
Detection is multi-layered: quick character check, word heuristics, and
|
||||
an optional langdetect fallback when available.
|
||||
"""
|
||||
text = text or ""
|
||||
|
||||
# 1) Quick check: Persian-specific letters
|
||||
if _PERSIAN_CHARS_RE.search(text):
|
||||
return True
|
||||
|
||||
# 2) Heuristic check for common Persian words (fast, precompiled)
|
||||
if _PERSIAN_HEURISTIC.search(text):
|
||||
return True
|
||||
|
||||
# 3) Fallback: language detection (more expensive) — only run if langdetect is available
|
||||
if _LANGDETECT_AVAILABLE and detect is not None:
|
||||
try:
|
||||
return detect(text) == "fa"
|
||||
except Exception as exc:
|
||||
# langdetect can fail for very short/ambiguous texts; log and continue
|
||||
logger.debug("langdetect detection failed: %s", exc)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Precompiled regex for Persian-specific characters (including Persian ye U+06CC)
|
||||
_PERSIAN_CHARS_RE = re.compile(r"[پچژگک\u06CC]")
|
||||
|
||||
# Optional langdetect import — import once at module import time to avoid repeated lookups
|
||||
_LANGDETECT_AVAILABLE = False
|
||||
try:
|
||||
from langdetect import DetectorFactory, detect # type: ignore
|
||||
|
||||
DetectorFactory.seed = 0
|
||||
_LANGDETECT_AVAILABLE = True
|
||||
except Exception:
|
||||
detect = None
|
||||
DetectorFactory = None
|
||||
_LANGDETECT_AVAILABLE = False
|
||||
|
||||
|
||||
class WorkflowServiceInterface(Protocol):
|
||||
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
|
||||
@ -59,35 +117,7 @@ class LLMGenerator:
|
||||
):
|
||||
prompt = CONVERSATION_TITLE_PROMPT
|
||||
|
||||
def _contains_persian(text: str) -> bool:
|
||||
# Normalize input once
|
||||
text = text or ""
|
||||
|
||||
# 1) Quick check: Persian-specific letters (پ چ ژ گ ک and persian ye U+06CC)
|
||||
if bool(re.search(r"[پچژگک\u06CC]", text)):
|
||||
return True
|
||||
|
||||
# 2) Heuristic check for common Persian words (fast, precompiled)
|
||||
if _PERSIAN_HEURISTIC.search(text):
|
||||
return True
|
||||
|
||||
# 3) Fallback: language detection (more expensive) — only run if langdetect is available
|
||||
try:
|
||||
import importlib
|
||||
|
||||
if importlib.util.find_spec("langdetect") is not None:
|
||||
langdetect = importlib.import_module("langdetect")
|
||||
DetectorFactory = langdetect.DetectorFactory
|
||||
detect = langdetect.detect
|
||||
|
||||
DetectorFactory.seed = 0
|
||||
if detect(text) == "fa":
|
||||
return True
|
||||
except Exception as exc:
|
||||
# langdetect may fail on short/ambiguous texts; log debug and continue
|
||||
logger.debug("langdetect detection failed: %s", exc)
|
||||
|
||||
return False
|
||||
# _contains_persian is implemented at module scope for reuse and testability
|
||||
|
||||
if len(query) > 2000:
|
||||
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
|
||||
@ -129,27 +159,35 @@ class LLMGenerator:
|
||||
model_parameters={"max_tokens": 500, "temperature": 0.2},
|
||||
stream=False,
|
||||
)
|
||||
except Exception:
|
||||
except (InvokeError, InvokeAuthorizationError):
|
||||
logger.exception("Failed to invoke LLM for conversation name generation")
|
||||
break
|
||||
|
||||
answer = cast(str, response.message.content)
|
||||
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
|
||||
if cleaned_answer is None:
|
||||
continue
|
||||
|
||||
# Parse JSON, try to repair malformed JSON if necessary
|
||||
candidate = ""
|
||||
result_dict = None
|
||||
try:
|
||||
result_dict = json.loads(cleaned_answer)
|
||||
except json.JSONDecodeError:
|
||||
def _extract_and_parse_json(raw_text: str) -> dict | None:
|
||||
if not raw_text:
|
||||
return None
|
||||
# Try to extract JSON object by braces
|
||||
first_brace = raw_text.find("{")
|
||||
last_brace = raw_text.rfind("}")
|
||||
if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
|
||||
candidate_json = raw_text[first_brace : last_brace + 1]
|
||||
else:
|
||||
candidate_json = raw_text
|
||||
|
||||
# Try normal json loads, then attempt to repair malformed JSON
|
||||
try:
|
||||
result_dict = json_repair.loads(cleaned_answer)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to parse LLM JSON when generating conversation name; using raw query as fallback"
|
||||
)
|
||||
return json.loads(candidate_json)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
repaired = json_repair.repair(candidate_json)
|
||||
return json.loads(repaired)
|
||||
except Exception as exc:
|
||||
logger.debug("JSON parse/repair failed: %s", exc)
|
||||
return None
|
||||
|
||||
result_dict = _extract_and_parse_json(answer)
|
||||
|
||||
if not isinstance(result_dict, dict):
|
||||
candidate = query
|
||||
@ -201,10 +239,8 @@ class LLMGenerator:
|
||||
translation = cast(str, translate_response.message.content).strip()
|
||||
if _contains_persian(translation):
|
||||
name = translation
|
||||
except InvokeError:
|
||||
except (InvokeError, InvokeAuthorizationError):
|
||||
logger.exception("Failed to obtain Persian translation for the conversation title")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error obtaining Persian translation for the conversation title")
|
||||
|
||||
if len(name) > 75:
|
||||
name = name[:75] + "..."
|
||||
|
||||
@ -11,7 +11,7 @@ Automatically identify the language of the user’s input (e.g. English, Chinese
|
||||
- The title must be natural, friendly, and in the same language as the input.
|
||||
- If the input is a direct question to the model, you may add an emoji at the end.
|
||||
|
||||
- Special Note for Persian (Farsi): If the input is Persian (Farsi), ALWAYS generate the title in Persian (Farsi). Use Persian characters (for example: پ، چ، ژ، گ، ک، ی) and ensure the "Language Type" field is "Persian" or "Farsi". Do NOT use Arabic or any other language or script when the input is Persian.
|
||||
- Special Note for Persian (Farsi): If the input is Persian (Farsi), ALWAYS generate the title in Persian (Farsi). Prefer using distinctly Persian characters (for example: پ، چ، ژ، گ). You may also use ک and ی, but prefer the Persian form (e.g., U+06CC for "ye"). Ensure the "Language Type" field is "Persian" or "Farsi". Do NOT use Arabic or any other language or script when the input is Persian.
|
||||
|
||||
3. Output Format
|
||||
Return **only** a valid JSON object with these exact keys and no additional text:
|
||||
|
||||
@ -63,3 +63,48 @@ def test_generate_conversation_name_translation_fallback(mock_get_model):
|
||||
# Final name should contain Persian character 'پ' from translation fallback
|
||||
assert "پ" in name
|
||||
assert model_instance.invoke_llm.call_count >= 3
|
||||
|
||||
|
||||
@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
|
||||
def test_generate_conversation_name_enforces_persian_retry_prompt(mock_get_model):
|
||||
# A Persian input containing Persian-specific character 'پ'
|
||||
persian_query = "سلام، چطوری؟ پ"
|
||||
|
||||
# First model response: misdetected as Arabic and returns Arabic title
|
||||
first_resp = DummyResponse(make_json_response("Arabic", "مرحبا"))
|
||||
# Second response (after retry): returns a Persian title with Persian-specific chars
|
||||
second_resp = DummyResponse(make_json_response("Persian", "عنوان پِرس"))
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.invoke_llm.side_effect = [first_resp, second_resp]
|
||||
|
||||
mock_get_model.return_value = model_instance
|
||||
|
||||
name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
|
||||
|
||||
# The final name should come from the Persian response (contains Persian-specific char 'پ')
|
||||
assert "پ" in name
|
||||
|
||||
# Ensure the retry prompt included a stronger Persian-only instruction
|
||||
assert model_instance.invoke_llm.call_count >= 2
|
||||
second_call_kwargs = model_instance.invoke_llm.call_args_list[1][1]
|
||||
prompt_msg = second_call_kwargs["prompt_messages"][0]
|
||||
assert "CRITICAL: You must output the title in Persian" in prompt_msg.content
|
||||
|
||||
|
||||
@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
|
||||
def test_generate_conversation_name_handles_invoke_error(mock_get_model):
|
||||
# If LLM invocation raises InvokeError, ensure fallback/translation is attempted and no exception bubbles
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
persian_query = "سلام، پ"
|
||||
|
||||
model_instance = MagicMock()
|
||||
# First invocation raises InvokeError; translation attempt returns Persian translation
|
||||
model_instance.invoke_llm.side_effect = [InvokeError("boom"), DummyResponse("عنوان ترجمه شده پ")]
|
||||
|
||||
mock_get_model.return_value = model_instance
|
||||
|
||||
name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
|
||||
|
||||
assert "پ" in name
|
||||
@ -1,4 +1,6 @@
|
||||
import sys, types, json
|
||||
import sys
|
||||
import types
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the repo `api/` directory is importable so tests can import `core.*` without external env setup
|
||||
@ -228,3 +230,28 @@ def test_generate_conversation_name_persian(monkeypatch):
|
||||
|
||||
# Assert: title should be the Persian string we returned
|
||||
assert name == "عنوان تستی"
|
||||
|
||||
|
||||
def test_contains_persian_character_and_heuristics(monkeypatch):
|
||||
from core.llm_generator.llm_generator import _contains_persian, _PERSIAN_CHARS_RE, _PERSIAN_HEURISTIC
|
||||
|
||||
# By single Persian-specific character
|
||||
assert _contains_persian("این یک تست پ") is True
|
||||
|
||||
# By heuristic Persian word
|
||||
assert _contains_persian("سلام دوست") is True
|
||||
|
||||
|
||||
def test_contains_persian_langdetect_fallback(monkeypatch):
|
||||
import core.llm_generator.llm_generator as lg
|
||||
|
||||
# Simulate langdetect being available and detecting Persian
|
||||
monkeypatch.setattr(lg, "_LANGDETECT_AVAILABLE", True)
|
||||
monkeypatch.setattr(lg, "detect", lambda text: "fa")
|
||||
|
||||
assert lg._contains_persian("short ambiguous text") is True
|
||||
|
||||
# Reset monkeypatch
|
||||
monkeypatch.setattr(lg, "_LANGDETECT_AVAILABLE", False)
|
||||
monkeypatch.setattr(lg, "detect", None)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user