diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 6b168fd4e8..4161443aea 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -53,6 +53,27 @@ class LLMGenerator:
     ):
         prompt = CONVERSATION_TITLE_PROMPT
 
+        def _contains_persian(text: str) -> bool:
+            # Detect presence of Persian-specific characters (پ چ ژ گ ک and Persian yeh, U+06CC)
+            if bool(re.search(r"[پچژگک\u06CC]", text or "")):
+                return True
+            # Fallback: use language detection to catch Persian text without special chars
+            try:
+                from langdetect import DetectorFactory, detect
+
+                DetectorFactory.seed = 0
+                lang = detect(text or "")
+                if lang == "fa":
+                    return True
+            except Exception as exc:
+                # langdetect may fail on very short texts; ignore failures.
+                # Log at debug level to aid debugging without tripping lint rule S110 (try-except-pass).
+                logger.debug("langdetect detection failed: %s", exc)
+            # Also check for some common Persian words as an additional heuristic
+            if bool(re.search(r"\b(سلام|متشکرم|ممنون|خوب|چطور|سپاس)\b", (text or ""), flags=re.IGNORECASE)):
+                return True
+            return False
+
         if len(query) > 2000:
             query = query[:300] + "...[TRUNCATED]..." + query[-300:]
 
@@ -65,23 +86,96 @@ class LLMGenerator:
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
         )
+
+        # If the input contains Persian characters, add an explicit instruction to produce a Persian title
+        is_persian_input = _contains_persian(query)
+
+        if is_persian_input:
+            prompt += (
+                "\nIMPORTANT: The user input is Persian (Farsi). "
+                "Only output the final title in Persian (Farsi), use Persian characters "
+                "(پ, چ, ژ, گ, ک, ی) and do NOT use Arabic or any other language.\n"
+            )
+
         prompts = [UserPromptMessage(content=prompt)]
 
         with measure_time() as timer:
-            response: LLMResult = model_instance.invoke_llm(
-                prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
-            )
-        answer = cast(str, response.message.content)
-        cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
-        if cleaned_answer is None:
-            return ""
-        try:
-            result_dict = json.loads(cleaned_answer)
-            answer = result_dict["Your Output"]
-        except json.JSONDecodeError:
-            logger.exception("Failed to generate name after answer, use query instead")
-            answer = query
-        name = answer.strip()
+            # Try generation with up to 2 attempts.
+            # If a Persian title is required but not produced, retry with a stronger instruction.
+            attempts = 0
+            max_attempts = 2
+            generated_output = None
+
+            while attempts < max_attempts:
+                attempts += 1
+                try:
+                    response: LLMResult = model_instance.invoke_llm(
+                        prompt_messages=list(prompts),
+                        model_parameters={"max_tokens": 500, "temperature": 0.2},
+                        stream=False,
+                    )
+                except Exception:
+                    logger.exception("Failed to invoke LLM for conversation name generation")
+                    break
+
+                answer = cast(str, response.message.content)
+                cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
+                if cleaned_answer is None:
+                    continue
+
+                try:
+                    result_dict = json.loads(cleaned_answer)
+                    candidate = result_dict.get("Your Output", "")
+                except json.JSONDecodeError:
+                    logger.exception(
+                        "Failed to parse LLM JSON when generating conversation name; "
+                        "using raw query as fallback"
+                    )
+                    candidate = query
+
+                # If the input is Persian, ensure the candidate contains Persian-specific characters;
+                # otherwise retry with a stronger instruction.
+                if is_persian_input and not _contains_persian(candidate):
+                    logger.info(
+                        "Generated title doesn't appear to be Persian; retrying with a stricter instruction"
+                    )
+                    prompts = [
+                        UserPromptMessage(
+                            content=(
+                                prompt
+                                + "\nCRITICAL: You must output the title in Persian (Farsi) "
+                                "using Persian-specific letters (پ, چ, ژ, گ, ک, ی). "
+                                "Output only the JSON as specified earlier."
+                            )
+                        )
+                    ]
+                    continue
+
+                generated_output = candidate.strip()
+                break
+
+        name = generated_output or (query or "")
+
+        if is_persian_input and not _contains_persian(name):
+            # As a last resort, ask the model to translate the title into Persian directly
+            try:
+                translate_prompt = UserPromptMessage(
+                    content=(
+                        "Translate the following short chat title into Persian (Farsi) ONLY. "
+                        "Output the Persian translation only (no JSON):\n\n"
+                        f"{name}"
+                    )
+                )
+                response = model_instance.invoke_llm(
+                    prompt_messages=[translate_prompt],
+                    model_parameters={"max_tokens": 200, "temperature": 0},
+                    stream=False,
+                )
+                translation = cast(str, response.message.content).strip()
+                if _contains_persian(translation):
+                    name = translation
+            except Exception:
+                logger.exception("Failed to obtain a Persian translation for the conversation title")
 
         if len(name) > 75:
             name = name[:75] + "..."
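Note on the detection tier above: the character-class check runs before langdetect is ever imported, so typical Persian input costs a single regex scan and the probabilistic detector is only a fallback. A minimal standalone sketch of that first tier (has_persian is a hypothetical name; the patch implements it as the nested _contains_persian):

    import re

    def has_persian(text: str) -> bool:
        # Persian-specific letters plus Persian yeh (U+06CC); standard Arabic uses none of these,
        # so this tier distinguishes Persian from Arabic without any language model.
        return bool(re.search(r"[پچژگک\u06CC]", text or ""))

    assert has_persian("گفتگو")        # contains Persian gaf (گ)
    assert not has_persian("مرحبا")    # Arabic greeting; no Persian-specific letters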
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 4f400129c1..2e7c96699f 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
     "json-repair>=0.41.1",
+    "langdetect~=1.0.9",
     "langfuse~=2.51.3",
     "langsmith~=0.1.77",
     "markdown~=3.5.1",
     "mlflow-skinny>=3.0.0",
     "numpy~=1.26.4",
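langdetect is a pure-Python port of the Java language-detection library; its results are probabilistic, which is why the patch sets DetectorFactory.seed before calling detect(). A quick sanity check of the fallback path (expected output hedged; very short strings may be labeled differently):

    from langdetect import DetectorFactory, detect

    DetectorFactory.seed = 0  # detection samples internally; seeding makes detect() reproducible
    print(detect("سلام دوست من، میخواهم درباره تنظیمات حساب صحبت کنم"))  # expected: "fa"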
diff --git a/tests/unit_tests/core/llm_generator/test_llm_generator_persian.py b/tests/unit_tests/core/llm_generator/test_llm_generator_persian.py
new file mode 100644
index 0000000000..d825b9c3ef
--- /dev/null
+++ b/tests/unit_tests/core/llm_generator/test_llm_generator_persian.py
@@ -0,0 +1,179 @@
+import json
+import sys
+import types
+from contextlib import contextmanager
+from pathlib import Path
+
+# Ensure the repo `api/` directory is importable so tests can import `core.*` without external env setup
+ROOT = Path(__file__).resolve().parents[4]  # repo root: four levels up from tests/unit_tests/core/llm_generator/
+sys.path.insert(0, str(ROOT / "api"))
+
+# Lightweight stubs to avoid importing heavy application modules during unit tests
+m = types.ModuleType('core.model_manager')
+class ModelManager:
+    def get_default_model_instance(self, tenant_id, model_type):
+        raise NotImplementedError
+
+    def get_model_instance(self, tenant_id, model_type, provider=None, model=None):
+        raise NotImplementedError
+
+m.ModelManager = ModelManager
+sys.modules['core.model_manager'] = m
+
+m2 = types.ModuleType('core.ops.ops_trace_manager')
+class TraceTask:
+    def __init__(self, *args, **kwargs):
+        # store attributes for potential inspection in tests
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        self.args = args
+        self.kwargs = kwargs
+
+class TraceQueueManager:
+    def __init__(self, *a, **k):
+        pass
+    def add_trace_task(self, *a, **k):
+        pass
+m2.TraceTask = TraceTask
+m2.TraceQueueManager = TraceQueueManager
+sys.modules['core.ops.ops_trace_manager'] = m2
+
+# Stub core.ops.utils to avoid importing heavy dependencies (db, models) during tests
+m_ops = types.ModuleType('core.ops.utils')
+@contextmanager
+def measure_time():
+    class Timer:
+        pass
+    t = Timer()
+    yield t
+m_ops.measure_time = measure_time
+sys.modules['core.ops.utils'] = m_ops
+
+m3 = types.ModuleType('core.model_runtime.entities.llm_entities')
+class LLMUsage:
+    @classmethod
+    def empty_usage(cls):
+        return cls()
+class LLMResult:
+    def __init__(self, model=None, prompt_messages=None, message=None, usage=None):
+        self.model = model
+        self.prompt_messages = prompt_messages
+        self.message = message
+        self.usage = usage
+m3.LLMUsage = LLMUsage
+m3.LLMResult = LLMResult
+sys.modules['core.model_runtime.entities.llm_entities'] = m3
+
+m4 = types.ModuleType('core.model_runtime.entities.message_entities')
+class PromptMessage:
+    def __init__(self, content=None):
+        self.content = content
+    def get_text_content(self):
+        return str(self.content) if self.content is not None else ""
+
+class TextPromptMessageContent:
+    def __init__(self, data):
+        self.data = data
+
+class ImagePromptMessageContent:
+    def __init__(self, url=None, base64_data=None, mime_type=None, filename=None):
+        self.url = url
+        self.base64_data = base64_data
+        self.mime_type = mime_type
+        self.filename = filename
+
+class DocumentPromptMessageContent:
+    def __init__(self, url=None):
+        self.url = url
+
+class AudioPromptMessageContent(DocumentPromptMessageContent):
+    pass
+
+class VideoPromptMessageContent(DocumentPromptMessageContent):
+    pass
+
+class AssistantPromptMessage(PromptMessage):
+    def __init__(self, content):
+        super().__init__(content)
+
+class UserPromptMessage(PromptMessage):
+    def __init__(self, content):
+        super().__init__(content)
+
+class SystemPromptMessage(PromptMessage):
+    def __init__(self, content=None):
+        super().__init__(content)
+
+m4.PromptMessage = PromptMessage
+m4.AssistantPromptMessage = AssistantPromptMessage
+m4.UserPromptMessage = UserPromptMessage
+m4.SystemPromptMessage = SystemPromptMessage
+m4.TextPromptMessageContent = TextPromptMessageContent
+m4.ImagePromptMessageContent = ImagePromptMessageContent
+m4.DocumentPromptMessageContent = DocumentPromptMessageContent
+m4.AudioPromptMessageContent = AudioPromptMessageContent
+m4.VideoPromptMessageContent = VideoPromptMessageContent
+sys.modules['core.model_runtime.entities.message_entities'] = m4
+
+m5 = types.ModuleType('core.model_runtime.entities.model_entities')
+class ModelType:
+    LLM = None
+m5.ModelType = ModelType
+sys.modules['core.model_runtime.entities.model_entities'] = m5
+
+# Stub minimal 'extensions' and 'models' packages to avoid importing heavy application code during tests
+ext_db = types.ModuleType('extensions.ext_database')
+ext_db.db = None
+sys.modules['extensions.ext_database'] = ext_db
+ext_storage = types.ModuleType('extensions.ext_storage')
+ext_storage.storage = None
+sys.modules['extensions.ext_storage'] = ext_storage
+
+models_m = types.ModuleType('models')
+class App: pass
+class Message: pass
+class WorkflowNodeExecutionModel: pass
+models_m.App = App
+models_m.Message = Message
+models_m.WorkflowNodeExecutionModel = WorkflowNodeExecutionModel
+sys.modules['models'] = models_m
+
+models_workflow = types.ModuleType('models.workflow')
+class Workflow: pass
+models_workflow.Workflow = Workflow
+sys.modules['models.workflow'] = models_workflow
+
+from core.llm_generator.llm_generator import LLMGenerator
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import AssistantPromptMessage
+from core.model_manager import ModelManager
+
+
+class DummyModelInstance:
+    def __init__(self, content):
+        self._content = content
+
+    def invoke_llm(self, prompt_messages=None, model_parameters=None, stream=False):
+        # Return an LLMResult-like object with the message content we expect
+        return LLMResult(
+            model="dummy",
+            prompt_messages=[],
+            message=AssistantPromptMessage(content=self._content),
+            usage=LLMUsage.empty_usage(),
+        )
+
+
+def test_generate_conversation_name_persian(monkeypatch):
+    # Arrange: a Persian query (contains Persian-specific letters such as ک, plus common Persian words)
+    query = "سلام دوست من، میخواهم درباره تنظیمات حساب صحبت کنم"
+
+    # Mock the default model instance to return a Persian title in JSON format
+    fake_output = json.dumps({"Your Output": "عنوان تستی"})
+    dummy = DummyModelInstance(fake_output)
+
+    monkeypatch.setattr(ModelManager, "get_default_model_instance", lambda self, tenant_id, model_type: dummy)
+
+    # Act
+    name = LLMGenerator.generate_conversation_name("tenant1", query)
+
+    # Assert: the title should be the Persian string we returned
+    assert "عنوان" in name or "تستی" in name
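A natural companion test, sketched here rather than included in the patch (it reuses the stubs and DummyModelInstance above; the test name is hypothetical and the final assertion assumes the generator returns the parsed title unchanged), would pin down the non-Persian path and confirm English input skips the retry and translation fallback:

    def test_generate_conversation_name_english(monkeypatch):
        # English input: _contains_persian(query) is False, so a single attempt and no translation step
        query = "I want to talk about my account settings"
        dummy = DummyModelInstance(json.dumps({"Your Output": "Account settings"}))
        monkeypatch.setattr(ModelManager, "get_default_model_instance", lambda self, tenant_id, model_type: dummy)

        name = LLMGenerator.generate_conversation_name("tenant1", query)

        assert name == "Account settings"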