mirror of https://github.com/langgenius/dify.git
style(llm-generator): fix lint issues (log langdetect errors, wrap long lines)
commit 0846542c33 (parent 3cb944f318)
@@ -53,6 +53,27 @@ class LLMGenerator:
     ):
         prompt = CONVERSATION_TITLE_PROMPT
 
+        def _contains_persian(text: str) -> bool:
+            # Detect presence of Persian-specific characters (پ چ ژ گ ک and Persian ye U+06CC)
+            if bool(re.search(r"[پچژگک\u06CC]", text or "")):
+                return True
+            # Fallback: use language detection to catch Persian text without special chars
+            try:
+                from langdetect import DetectorFactory, detect
+
+                DetectorFactory.seed = 0
+                lang = detect(text or "")
+                if lang == "fa":
+                    return True
+            except Exception as exc:
+                # langdetect may fail on very short texts; ignore failures.
+                # Log at debug level to aid debugging without tripping lint rule S110.
+                logger.debug("langdetect detection failed: %s", exc)
+            # Also check for some common Persian words as an additional heuristic
+            if bool(re.search(r"\b(سلام|متشکرم|ممنون|خوب|چطور|سپاس)\b", (text or ""), flags=re.IGNORECASE)):
+                return True
+            return False
+
         if len(query) > 2000:
             query = query[:300] + "...[TRUNCATED]..." + query[-300:]
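Note: the character-class test alone covers most Persian text; the langdetect fallback matters for input written only with letters shared with Arabic. A standalone sketch of the same heuristic, assuming langdetect is installed (the function name and demo strings here are illustrative, not part of the commit):

import re

def contains_persian(text: str) -> bool:
    # Letters پ چ ژ گ ک and Persian ye (U+06CC) do not occur in standard Arabic.
    if re.search(r"[پچژگک\u06CC]", text or ""):
        return True
    try:
        from langdetect import DetectorFactory, detect

        DetectorFactory.seed = 0  # langdetect is probabilistic; seeding makes results repeatable
        return detect(text or "") == "fa"
    except Exception:
        return False  # very short or empty input raises LangDetectException

print(contains_persian("گفتگو درباره تنظیمات"))  # True: contains گ, a Persian-specific letter
print(contains_persian("hello world"))           # False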
@@ -65,23 +86,96 @@ class LLMGenerator:
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
         )
 
+        # If the input contains Persian characters, add an explicit instruction to produce a Persian title
+        is_persian_input = _contains_persian(query)
+
+        if is_persian_input:
+            prompt += (
+                "\nIMPORTANT: The user input is Persian (Farsi). "
+                "Only output the final title in Persian (Farsi), use Persian characters "
+                "(پ, چ, ژ, گ, ک, ی) and do NOT use Arabic or any other language.\n"
+            )
+
         prompts = [UserPromptMessage(content=prompt)]
 
-        with measure_time() as timer:
-            response: LLMResult = model_instance.invoke_llm(
-                prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
-            )
-        answer = cast(str, response.message.content)
-        cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
-        if cleaned_answer is None:
-            return ""
-        try:
-            result_dict = json.loads(cleaned_answer)
-            answer = result_dict["Your Output"]
-        except json.JSONDecodeError:
-            logger.exception("Failed to generate name after answer, use query instead")
-            answer = query
-        name = answer.strip()
+        # Try generation with up to 2 attempts.
+        # If a Persian title is required but not produced, retry with a stronger instruction.
+        attempts = 0
+        max_attempts = 2
+        generated_output = None
+
+        while attempts < max_attempts:
+            attempts += 1
+            try:
+                response: LLMResult = model_instance.invoke_llm(
+                    prompt_messages=list(prompts),
+                    model_parameters={"max_tokens": 500, "temperature": 0.2},
+                    stream=False,
+                )
+            except Exception:
+                logger.exception("Failed to invoke LLM for conversation name generation")
+                break
+
+            answer = cast(str, response.message.content)
+            cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
+            if cleaned_answer is None:
+                continue
+
+            try:
+                result_dict = json.loads(cleaned_answer)
+                candidate = result_dict.get("Your Output", "")
+            except json.JSONDecodeError:
+                logger.exception(
+                    "Failed to parse LLM JSON when generating conversation name; "
+                    "using raw query as fallback"
+                )
+                candidate = query
+
+            # If the input is Persian, ensure the candidate contains Persian-specific characters;
+            # otherwise retry with a stronger instruction.
+            if is_persian_input and not _contains_persian(candidate):
+                logger.info(
+                    "Generated title doesn't appear to be Persian; retrying with stricter instruction"
+                )
+                prompts = [
+                    UserPromptMessage(
+                        content=(
+                            prompt
+                            + "\nCRITICAL: You must output the title in Persian (Farsi) "
+                            "using Persian-specific letters (پ, چ, ژ, گ, ک, ی). "
+                            "Output only the JSON as specified earlier."
+                        )
+                    )
+                ]
+                continue
+
+            generated_output = candidate.strip()
+            break
+
+        name = generated_output or (query or "")
+
+        if is_persian_input and not _contains_persian(name):
+            # As a last resort, ask the model to translate the title into Persian directly
+            try:
+                translate_prompt = UserPromptMessage(
+                    content=(
+                        "Translate the following short chat title into Persian (Farsi) ONLY. "
+                        "Output the Persian translation only (no JSON):\n\n"
+                        f"{name}"
+                    )
+                )
+                response = model_instance.invoke_llm(
+                    prompt_messages=[translate_prompt],
+                    model_parameters={"max_tokens": 200, "temperature": 0},
+                    stream=False,
+                )
+                translation = cast(str, response.message.content).strip()
+                if _contains_persian(translation):
+                    name = translation
+            except Exception:
+                logger.exception("Failed to obtain Persian translation for the conversation title")
+
         if len(name) > 75:
             name = name[:75] + "..."
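Note: the loop introduced above follows a generate-validate-retry pattern. A minimal self-contained sketch of the same idea, with a hypothetical invoke callable standing in for model_instance.invoke_llm and a pluggable validate predicate:

import json
import re

def generate_title(invoke, prompt, validate, max_attempts=2):
    current_prompt = prompt
    for _ in range(max_attempts):
        raw = invoke(current_prompt)
        # Keep only the outermost {...} block, as the code above does.
        cleaned = re.sub(r"^.*(\{.*\}).*$", r"\1", raw, flags=re.DOTALL)
        try:
            candidate = json.loads(cleaned).get("Your Output", "")
        except json.JSONDecodeError:
            continue
        if validate(candidate):
            return candidate.strip()
        # Tighten the instruction before the next attempt.
        current_prompt = prompt + "\nCRITICAL: follow the output-language requirement exactly."
    return None

# Stand-in model call that always returns well-formed JSON.
title = generate_title(
    invoke=lambda p: '{"Your Output": "گفتگو درباره تنظیمات"}',
    prompt="Generate a JSON title.",
    validate=lambda s: bool(s.strip()),
)
print(title)  # گفتگو درباره تنظیمات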
@@ -33,6 +33,7 @@ dependencies = [
     "json-repair>=0.41.1",
     "langfuse~=2.51.3",
     "langsmith~=0.1.77",
+    "langdetect~=1.0.9",
     "markdown~=3.5.1",
     "mlflow-skinny>=3.0.0",
     "numpy~=1.26.4",
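Note: ~=1.0.9 is a compatible-release specifier, equivalent to >=1.0.9,<1.1. A hypothetical runtime sanity check of the resolved version:

from importlib.metadata import version

v = version("langdetect")  # e.g. "1.0.9"
assert v.startswith("1.0."), f"unexpected langdetect version: {v}"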
@@ -0,0 +1,177 @@
import json
import sys
import types
from contextlib import contextmanager
from pathlib import Path

# Ensure the repo `api/` directory is importable so tests can import `core.*` without external env setup
ROOT = Path(__file__).resolve().parents[3]
sys.path.insert(0, str(ROOT / "api"))

# Lightweight stubs to avoid importing heavy application modules during unit tests
m = types.ModuleType("core.model_manager")

class ModelManager:
    def get_default_model_instance(self, tenant_id, model_type):
        raise NotImplementedError

    def get_model_instance(self, tenant_id, model_type, provider=None, model=None):
        raise NotImplementedError

m.ModelManager = ModelManager
sys.modules["core.model_manager"] = m

m2 = types.ModuleType("core.ops.ops_trace_manager")

class TraceTask:
    def __init__(self, *args, **kwargs):
        # store attributes for potential inspection in tests
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.args = args
        self.kwargs = kwargs

class TraceQueueManager:
    def __init__(self, *a, **k):
        pass

    def add_trace_task(self, *a, **k):
        pass

m2.TraceTask = TraceTask
m2.TraceQueueManager = TraceQueueManager
sys.modules["core.ops.ops_trace_manager"] = m2

# Stub core.ops.utils to avoid importing heavy dependencies (db, models) during tests
m_ops = types.ModuleType("core.ops.utils")

@contextmanager
def measure_time():
    class Timer:
        pass

    t = Timer()
    yield t

m_ops.measure_time = measure_time
sys.modules["core.ops.utils"] = m_ops

m3 = types.ModuleType("core.model_runtime.entities.llm_entities")

class LLMUsage:
    @classmethod
    def empty_usage(cls):
        return cls()

class LLMResult:
    def __init__(self, model=None, prompt_messages=None, message=None, usage=None):
        self.model = model
        self.prompt_messages = prompt_messages
        self.message = message
        self.usage = usage

m3.LLMUsage = LLMUsage
m3.LLMResult = LLMResult
sys.modules["core.model_runtime.entities.llm_entities"] = m3

m4 = types.ModuleType("core.model_runtime.entities.message_entities")

class PromptMessage:
    def __init__(self, content=None):
        self.content = content

    def get_text_content(self):
        return str(self.content) if self.content is not None else ""

class TextPromptMessageContent:
    def __init__(self, data):
        self.data = data

class ImagePromptMessageContent:
    def __init__(self, url=None, base64_data=None, mime_type=None, filename=None):
        self.url = url
        self.base64_data = base64_data
        self.mime_type = mime_type
        self.filename = filename

class DocumentPromptMessageContent:
    def __init__(self, url=None):
        self.url = url

class AudioPromptMessageContent(DocumentPromptMessageContent):
    pass

class VideoPromptMessageContent(DocumentPromptMessageContent):
    pass

class AssistantPromptMessage(PromptMessage):
    def __init__(self, content):
        super().__init__(content)

class UserPromptMessage(PromptMessage):
    def __init__(self, content):
        super().__init__(content)

class SystemPromptMessage(PromptMessage):
    def __init__(self, content=None):
        super().__init__(content)

m4.PromptMessage = PromptMessage
m4.AssistantPromptMessage = AssistantPromptMessage
m4.UserPromptMessage = UserPromptMessage
m4.SystemPromptMessage = SystemPromptMessage
m4.TextPromptMessageContent = TextPromptMessageContent
m4.ImagePromptMessageContent = ImagePromptMessageContent
m4.DocumentPromptMessageContent = DocumentPromptMessageContent
m4.AudioPromptMessageContent = AudioPromptMessageContent
m4.VideoPromptMessageContent = VideoPromptMessageContent
sys.modules["core.model_runtime.entities.message_entities"] = m4

m5 = types.ModuleType("core.model_runtime.entities.model_entities")

class ModelType:
    LLM = None

m5.ModelType = ModelType
sys.modules["core.model_runtime.entities.model_entities"] = m5

# Stub minimal 'extensions' and 'models' packages to avoid importing heavy application code during tests
ext_db = types.ModuleType("extensions.ext_database")
ext_db.db = None
sys.modules["extensions.ext_database"] = ext_db
ext_storage = types.ModuleType("extensions.ext_storage")
ext_storage.storage = None
sys.modules["extensions.ext_storage"] = ext_storage

models_m = types.ModuleType("models")

class App:
    pass

class Message:
    pass

class WorkflowNodeExecutionModel:
    pass

models_m.App = App
models_m.Message = Message
models_m.WorkflowNodeExecutionModel = WorkflowNodeExecutionModel
sys.modules["models"] = models_m

models_workflow = types.ModuleType("models.workflow")

class Workflow:
    pass

models_workflow.Workflow = Workflow
sys.modules["models.workflow"] = models_workflow

# Imported only after the stubs above are registered in sys.modules
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.model_manager import ModelManager


class DummyModelInstance:
    def __init__(self, content):
        self._content = content

    def invoke_llm(self, prompt_messages=None, model_parameters=None, stream=False):
        # Return an LLMResult-like object with the message content we expect
        return LLMResult(
            model="dummy",
            prompt_messages=[],
            message=AssistantPromptMessage(content=self._content),
            usage=LLMUsage.empty_usage(),
        )


def test_generate_conversation_name_persian(monkeypatch):
    # Arrange: Persian input that doesn't necessarily include Persian-specific letters
    query = "سلام دوست من، میخواهم درباره تنظیمات حساب صحبت کنم"

    # Mock the default model instance to return a Persian title in JSON format
    fake_output = json.dumps({"Your Output": "عنوان تستی"})
    dummy = DummyModelInstance(fake_output)

    monkeypatch.setattr(ModelManager, "get_default_model_instance", lambda self, tenant_id, model_type: dummy)

    # Act
    name = LLMGenerator.generate_conversation_name("tenant1", query)

    # Assert: title should be the Persian string we returned
    assert "عنوان" in name or "تستی" in name
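Note: the test relies on registering synthetic modules in sys.modules before importing the unit under test, a pattern that works for any heavyweight dependency. A minimal sketch with made-up module names:

import sys
import types

# Build a stub for a heavyweight dependency before anything imports it.
pkg = types.ModuleType("fake_pkg")            # hypothetical parent package
dep = types.ModuleType("fake_pkg.heavy_dep")  # hypothetical submodule
dep.connect = lambda: "stub connection"
pkg.heavy_dep = dep
sys.modules["fake_pkg"] = pkg
sys.modules["fake_pkg.heavy_dep"] = dep

# Any later import now resolves to the stub instead of the real module.
from fake_pkg.heavy_dep import connect

print(connect())  # stub connection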