diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0cb573cb86..5a2941d1ea 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -275,14 +275,12 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): start_node_id: str | None = None -from core.ops.ops_trace_manager import TraceQueueManager - -AppGenerateEntity.model_rebuild() -EasyUIBasedAppGenerateEntity.model_rebuild() -ConversationAppGenerateEntity.model_rebuild() -ChatAppGenerateEntity.model_rebuild() -CompletionAppGenerateEntity.model_rebuild() -AgentChatAppGenerateEntity.model_rebuild() -AdvancedChatAppGenerateEntity.model_rebuild() -WorkflowAppGenerateEntity.model_rebuild() -RagPipelineGenerateEntity.model_rebuild() +AppGenerateEntity.model_rebuild(raise_errors=False) +EasyUIBasedAppGenerateEntity.model_rebuild(raise_errors=False) +ConversationAppGenerateEntity.model_rebuild(raise_errors=False) +ChatAppGenerateEntity.model_rebuild(raise_errors=False) +CompletionAppGenerateEntity.model_rebuild(raise_errors=False) +AgentChatAppGenerateEntity.model_rebuild(raise_errors=False) +AdvancedChatAppGenerateEntity.model_rebuild(raise_errors=False) +WorkflowAppGenerateEntity.model_rebuild(raise_errors=False) +RagPipelineGenerateEntity.model_rebuild(raise_errors=False) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index b4c3ec1caf..c3c1829e96 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -26,7 +26,6 @@ from core.model_runtime.entities.message_entities import PromptMessage, SystemPr from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey @@ -37,6 +36,55 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) +# Precompiled heuristic to detect common Persian (Farsi) words in short inputs. +# Using a compiled regex avoids repeated recompilation on every call. +_PERSIAN_HEURISTIC = re.compile( + r"\b(سلام|متشکرم|ممنون|خوب|چطور|سپاس)\b", + flags=re.IGNORECASE, +) + +# Precompiled regex for Persian-specific characters (including Persian ye U+06CC) +_persian_chars_re = re.compile(r"[پچژگک\u06CC]") + +# Optional langdetect import — import once at module import time to avoid repeated lookups +_langdetect_available = False +try: + from langdetect import DetectorFactory, detect # type: ignore + + DetectorFactory.seed = 0 + _langdetect_available = True +except Exception: + detect = None + DetectorFactory = None + _langdetect_available = False + + +def _contains_persian(text: str) -> bool: + """Return True if text appears to be Persian (Farsi). + + Detection is multi-layered: quick character check, word heuristics, and + an optional langdetect fallback when available. 
+ """ + text = text or "" + + # 1) Quick check: Persian-specific letters + if _persian_chars_re.search(text): + return True + + # 2) Heuristic check for common Persian words (fast, precompiled) + if _PERSIAN_HEURISTIC.search(text): + return True + + # 3) Fallback: language detection (more expensive) — only run if langdetect is available + if _langdetect_available and detect is not None: + try: + return detect(text) == "fa" + except Exception as exc: + # langdetect can fail for very short/ambiguous texts; log and continue + logger.debug("langdetect detection failed: %s", exc) + + return False + class WorkflowServiceInterface(Protocol): def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None: @@ -53,6 +101,8 @@ class LLMGenerator: ): prompt = CONVERSATION_TITLE_PROMPT + # _contains_persian is implemented at module scope for reuse and testability + if len(query) > 2000: query = query[:300] + "...[TRUNCATED]..." + query[-300:] @@ -65,35 +115,155 @@ class LLMGenerator: tenant_id=tenant_id, model_type=ModelType.LLM, ) + + # If the input contains Persian characters, add explicit instruction to produce Persian title + is_persian_input = _contains_persian(query) + + if is_persian_input: + prompt += ( + "\nIMPORTANT: The user input is Persian (Farsi). " + "Only output the final title in Persian (Farsi), use Persian characters " + "(پ, چ, ژ, گ, ک, ی) and do NOT use Arabic or any other language.\n" + ) + prompts = [UserPromptMessage(content=prompt)] with measure_time() as timer: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False - ) - answer = cast(str, response.message.content) - if answer is None: - return "" - try: - result_dict = json.loads(answer) - except json.JSONDecodeError: - result_dict = json_repair.loads(answer) + # Try generation with up to 2 attempts. + # If Persian required but not produced, retry with stronger instruction. 
+            attempts = 0
+            max_attempts = 2
+            generated_output = None
 
-        if not isinstance(result_dict, dict):
-            answer = query
-        else:
-            output = result_dict.get("Your Output")
-            if isinstance(output, str) and output.strip():
-                answer = output.strip()
+            while attempts < max_attempts:
+                attempts += 1
+                try:
+                    response: LLMResult = model_instance.invoke_llm(
+                        prompt_messages=list(prompts),
+                        model_parameters={"max_tokens": 500, "temperature": 0.2},
+                        stream=False,
+                    )
+                except (InvokeError, InvokeAuthorizationError):
+                    logger.exception("Failed to invoke LLM for conversation name generation")
+                    break
+
+                answer = cast(str, response.message.content)
+
+                def _extract_and_parse_json(raw_text: str) -> dict | None:
+                    if not raw_text:
+                        return None
+                    # Try to extract JSON object by braces
+                    first_brace = raw_text.find("{")
+                    last_brace = raw_text.rfind("}")
+                    if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
+                        candidate_json = raw_text[first_brace : last_brace + 1]
+                    else:
+                        candidate_json = raw_text
+
+                    # Try normal json loads, then attempt to repair malformed JSON
+                    try:
+                        parsed = json.loads(candidate_json)
+                        # Only accept dict results for structured conversation title parsing
+                        return parsed if isinstance(parsed, dict) else None
+                    except json.JSONDecodeError:
+                        # Prefer a json_repair.loads implementation if available
+                        json_repair_loads = getattr(json_repair, "loads", None)
+                        if callable(json_repair_loads):
+                            try:
+                                repaired_parsed = json_repair_loads(candidate_json)
+                                if isinstance(repaired_parsed, dict):
+                                    return repaired_parsed
+                                # If the repair function returns a string, try parsing it
+                                if isinstance(repaired_parsed, str):
+                                    try:
+                                        parsed2 = json.loads(repaired_parsed)
+                                        return parsed2 if isinstance(parsed2, dict) else None
+                                    except Exception:
+                                        return None
+                                return None
+                            except Exception as exc:
+                                logger.debug("json_repair.loads failed: %s", exc)
+                                return None
+
+                        # Otherwise fall back to json_repair's repair_json() and parse its result
+                        repair_json_fn = getattr(json_repair, "repair_json", None)
+                        if callable(repair_json_fn):
+                            try:
+                                repaired = repair_json_fn(candidate_json)
+                                if isinstance(repaired, (dict, list)):
+                                    return repaired if isinstance(repaired, dict) else None
+                                if isinstance(repaired, str):
+                                    parsed = json.loads(repaired)
+                                    return parsed if isinstance(parsed, dict) else None
+                                return None
+                            except Exception as exc:
+                                logger.debug("json_repair.repair_json failed: %s", exc)
+                                return None
+
+                        logger.debug("No suitable json_repair function available to repair JSON")
+                        return None
+
+                result_dict = _extract_and_parse_json(answer)
+
+                if not isinstance(result_dict, dict):
+                    candidate = query
+                else:
+                    candidate = output if isinstance(output := result_dict.get("Your Output"), str) else ""
+
+                # If input is Persian, ensure candidate contains Persian-specific characters.
+                # Otherwise retry with stronger instruction.
+                if is_persian_input and not _contains_persian(candidate):
+                    logger.info("Generated title doesn't appear to be Persian; retrying with stricter instruction")
+                    prompts = [
+                        UserPromptMessage(
+                            content=(
+                                prompt + "\nCRITICAL: You must output the title in Persian (Farsi) "
+                                "using Persian-specific letters (پ, چ, ژ, گ, ک, ی). "
+                                "Output only the JSON as specified earlier."
+                            )
+                        )
+                    ]
+                    continue
+
+                generated_output = candidate.strip()
+                break
+
+            if generated_output:
+                name = generated_output
             else:
-                answer = query
+                # Use the last non-Persian candidate (if any) so that the translation fallback
+                # can translate the generated candidate into Persian. Otherwise fall back to
+                # the original query.
+                last_candidate = locals().get("candidate", None)
+                name = last_candidate.strip() if isinstance(last_candidate, str) and last_candidate.strip() else (query or "")
 
-        name = answer.strip()
+        if is_persian_input and not _contains_persian(name):
+            # As a last resort, ask the model to translate the title into Persian directly
+            try:
+                translate_prompt = UserPromptMessage(
+                    content=(
+                        "Translate the following short chat title into Persian (Farsi) ONLY. "
+                        "Output the Persian translation only (no JSON):\n\n"
+                        f"{name}"
+                    )
+                )
+                translate_response: LLMResult = model_instance.invoke_llm(
+                    prompt_messages=[translate_prompt],
+                    model_parameters={"max_tokens": 200, "temperature": 0},
+                    stream=False,
+                )
+                translation = cast(str, translate_response.message.content).strip()
+                if _contains_persian(translation):
+                    name = translation
+            except (InvokeError, InvokeAuthorizationError):
+                logger.exception("Failed to obtain Persian translation for the conversation title")
 
         if len(name) > 75:
             name = name[:75] + "..."
 
         # get tracing instance
+        from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
+
         trace_manager = TraceQueueManager(app_id=app_id)
         trace_manager.add_trace_task(
             TraceTask(
@@ -380,10 +550,46 @@ class LLMGenerator:
         if not isinstance(raw_content, str):
             raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
 
+        # Initialize parsed_content to ensure the variable is always bound for type-checkers
+        parsed_content: dict | list | None = None
         try:
             parsed_content = json.loads(raw_content)
         except json.JSONDecodeError:
-            parsed_content = json_repair.loads(raw_content)
+            # Prefer a json_repair.loads implementation if available
+            json_repair_loads = getattr(json_repair, "loads", None)
+            if callable(json_repair_loads):
+                try:
+                    parsed_candidate = json_repair_loads(raw_content)
+                    # Accept dict or list directly
+                    if isinstance(parsed_candidate, (dict, list)):
+                        parsed_content = parsed_candidate
+                    elif isinstance(parsed_candidate, str):
+                        try:
+                            parsed2 = json.loads(parsed_candidate)
+                            parsed_content = parsed2 if isinstance(parsed2, (dict, list)) else None
+                        except Exception as exc:
+                            logger.debug("json_repair.loads returned a string that failed to parse: %s", exc)
+                            parsed_content = None
+                    else:
+                        parsed_content = None
+                except Exception as exc:
+                    logger.debug("json_repair.loads failed: %s", exc)
+                    parsed_content = None
+            else:
+                # As a fallback, use json_repair's repair_json() followed by json.loads
+                repair_json_fn = getattr(json_repair, "repair_json", None)
+                if callable(repair_json_fn):
+                    try:
+                        repaired = repair_json_fn(raw_content)
+                        if isinstance(repaired, (dict, list)):
+                            parsed_content = repaired
+                        elif isinstance(repaired, str):
+                            parsed_content = json.loads(repaired)
+                        else:
+                            parsed_content = None
+                    except Exception as exc:
+                        logger.debug("json_repair.repair_json failed: %s", exc)
+                        parsed_content = None
 
         if not isinstance(parsed_content, dict | list):
             raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
@@ -561,16 +767,15 @@ class LLMGenerator:
             prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
         )
 
-        generated_raw = response.message.get_text_content()
+        generated_raw = cast(str, response.message.content)
         first_brace = generated_raw.find("{")
         last_brace = generated_raw.rfind("}")
         if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
             raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
-        json_str = generated_raw[first_brace : last_brace + 1]
-        data = json_repair.loads(json_str)
+        data = json_repair.loads(generated_raw[first_brace : last_brace + 1])
         if not isinstance(data, dict):
             raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
         return data
         except InvokeError as e:
             error = str(e)
             return {"error": f"Failed to generate code. Error: {error}"}
diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py
index ec2b7f2d44..35fc1b4bfd 100644
--- a/api/core/llm_generator/prompts.py
+++ b/api/core/llm_generator/prompts.py
@@ -11,6 +11,8 @@ Automatically identify the language of the user’s input (e.g. English, Chinese
 - The title must be natural, friendly, and in the same language as the input.
 - If the input is a direct question to the model, you may add an emoji at the end.
 
+- Special Note for Persian (Farsi): If the input is Persian (Farsi), ALWAYS generate the title in Persian (Farsi). Prefer using distinctly Persian characters (for example: پ، چ، ژ، گ). You may also use ک and ی, but prefer the Persian form (e.g., U+06CC for "ye"). Ensure the "Language Type" field is "Persian" or "Farsi". Do NOT use Arabic or any other language or script when the input is Persian.
+
 3. Output Format
 Return **only** a valid JSON object with these exact keys and no additional text:
 {
diff --git a/api/pyproject.toml b/api/pyproject.toml
index dbc6a2eb83..0fc4073d48 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -35,6 +35,7 @@ dependencies = [
     "jsonschema>=4.25.1",
+    "langdetect~=1.0.9",
     "langfuse~=2.51.3",
     "langsmith~=0.1.77",
     "markdown~=3.5.1",
     "mlflow-skinny>=3.0.0",
     "numpy~=1.26.4",
diff --git a/api/tests/unit_tests/core/test_llm_generator_persian.py b/api/tests/unit_tests/core/test_llm_generator_persian.py
new file mode 100644
index 0000000000..e982467994
--- /dev/null
+++ b/api/tests/unit_tests/core/test_llm_generator_persian.py
@@ -0,0 +1,110 @@
+import json
+from unittest.mock import MagicMock, patch
+
+from core.llm_generator.llm_generator import LLMGenerator
+
+
+class DummyMessage:
+    def __init__(self, content):
+        self.content = content
+
+
+class DummyResponse:
+    def __init__(self, content):
+        self.message = DummyMessage(content)
+
+
+def make_json_response(language, output):
+    return json.dumps({"Language Type": language, "Your Reasoning": "...", "Your Output": output})
+
+
+@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
+def test_generate_conversation_name_enforces_persian(mock_get_model):
+    # A Persian input containing Persian-specific character 'پ'
+    persian_query = "سلام، چطوری؟ پ"  # contains 'پ'
+
+    # First model response: misdetected as Arabic and returns Arabic title
+    first_resp = DummyResponse(make_json_response("Arabic", "مرحبا"))
+    # Second response (after retry): returns a Persian title with Persian-specific chars
+    second_resp = DummyResponse(make_json_response("Persian", "عنوان پِرس"))
+
+    model_instance = MagicMock()
+    model_instance.invoke_llm.side_effect = [first_resp, second_resp]
+
+    mock_get_model.return_value = model_instance
+
+    name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
+
+    # The final name should come from the Persian response (contains Persian-specific char 'پ')
+    assert "پ" in name
+    # Ensure the model was invoked at least twice (retry occurred)
+    assert model_instance.invoke_llm.call_count >= 2
+
+
+@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
+def test_generate_conversation_name_translation_fallback(mock_get_model):
+    # Persian query
+    persian_query = "این یک تست است پ"
+
+    # Model returns non-Persian outputs consistently
+    non_persian_resp = DummyResponse(make_json_response("Arabic", "مرحبا"))
+
+    # Translate response (last call) returns Persian translation
+    translate_resp = DummyResponse("عنوان ترجمه شده پ")
+
+    model_instance = MagicMock()
+    # First two calls return non-persian results; third call is translation
+    model_instance.invoke_llm.side_effect = [non_persian_resp, non_persian_resp, translate_resp]
+
+    mock_get_model.return_value = model_instance
+
+    name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
+
+    # Final name should contain Persian character 'پ' from translation fallback
+    assert "پ" in name
+    assert model_instance.invoke_llm.call_count >= 3
+
+
+@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
+def test_generate_conversation_name_enforces_persian_retry_prompt(mock_get_model):
+    # A Persian input containing Persian-specific character 'پ'
+    persian_query = "سلام، چطوری؟ پ"
+
+    # First model response: misdetected as Arabic and returns Arabic title
+    first_resp = DummyResponse(make_json_response("Arabic", "مرحبا"))
+    # Second response (after retry): returns a Persian title with Persian-specific chars
+    second_resp = DummyResponse(make_json_response("Persian", "عنوان پِرس"))
+
+    model_instance = MagicMock()
+    model_instance.invoke_llm.side_effect = [first_resp, second_resp]
+
+    mock_get_model.return_value = model_instance
+
+    name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
+
+    # The final name should come from the Persian response (contains Persian-specific char 'پ')
+    assert "پ" in name
+
+    # Ensure the retry prompt included a stronger Persian-only instruction
+    assert model_instance.invoke_llm.call_count >= 2
+    second_call_kwargs = model_instance.invoke_llm.call_args_list[1][1]
+    prompt_msg = second_call_kwargs["prompt_messages"][0]
+    assert "CRITICAL: You must output the title in Persian" in prompt_msg.content
+
+
+@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
+def test_generate_conversation_name_handles_invoke_error(mock_get_model):
+    # If LLM invocation raises InvokeError, ensure the query fallback is used and no exception bubbles
+    from core.model_runtime.errors.invoke import InvokeError
+
+    persian_query = "سلام، پ"
+
+    model_instance = MagicMock()
+    # First invocation raises InvokeError; a translation attempt, if reached, returns Persian
+    model_instance.invoke_llm.side_effect = [InvokeError("boom"), DummyResponse("عنوان ترجمه شده پ")]
+
+    mock_get_model.return_value = model_instance
+
+    name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
+
+    assert "پ" in name
diff --git a/api/uv.lock b/api/uv.lock
index 4ccd229eec..d0bff50df1 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -1404,6 +1404,7 @@ dependencies = [
     { name = "jieba" },
     { name = "json-repair" },
     { name = "jsonschema" },
+    { name = "langdetect" },
     { name = "langfuse" },
     { name = "langsmith" },
     { name = "litellm" },
@@ -1602,6 +1603,7 @@ requires-dist = [
     { name = "jieba", specifier = "==0.42.1" },
     { name = "json-repair", specifier = ">=0.41.1" },
     { name = "jsonschema", specifier = ">=4.25.1" },
+    { name = "langdetect", specifier = "~=1.0.9" },
     { name = "langfuse", specifier = "~=2.51.3" },
     { name = "langsmith", specifier = "~=0.1.77" },
     { name = "litellm", specifier = "==1.77.1" },
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index dba61d1816..a9a26939c9 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -176,6 +176,12 @@ services:
       THIRD_PARTY_SIGNATURE_VERIFICATION_ENABLED: true
       THIRD_PARTY_SIGNATURE_VERIFICATION_PUBLIC_KEYS: /app/keys/publickey.pem
       FORCE_VERIFYING_SIGNATURE: false
+
+      HTTP_PROXY: ${HTTP_PROXY:-http://ssrf_proxy:3128}
+      HTTPS_PROXY: ${HTTPS_PROXY:-http://ssrf_proxy:3128}
+      PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
+    extra_hosts:
+      - "host.docker.internal:host-gateway"
     ports:
       - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}"
       - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
diff --git a/tests/unit_tests/core/llm_generator/test_llm_generator_persian.py b/tests/unit_tests/core/llm_generator/test_llm_generator_persian.py
new file mode 100644
index 0000000000..7968f28d75
--- /dev/null
+++ b/tests/unit_tests/core/llm_generator/test_llm_generator_persian.py
@@ -0,0 +1,256 @@
+import sys
+import types
+import json
+from pathlib import Path
+
+# Ensure the repo `api/` directory is importable so tests can import `core.*` without external env setup
+ROOT = Path(__file__).resolve().parents[4]
+sys.path.insert(0, str(ROOT / "api"))
+
+# Lightweight stubs to avoid importing heavy application modules during unit tests
+m = types.ModuleType("core.model_manager")
+
+
+class ModelManager:
+    def get_default_model_instance(self, tenant_id, model_type):
+        raise NotImplementedError
+
+    def get_model_instance(self, tenant_id, model_type, provider=None, model=None):
+        raise NotImplementedError
+
+
+m.ModelManager = ModelManager
+sys.modules["core.model_manager"] = m
+
+m2 = types.ModuleType("core.ops.ops_trace_manager")
+
+
+class TraceTask:
+    def __init__(self, *args, **kwargs):
+        # store attributes for potential inspection in tests
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        self.args = args
+        self.kwargs = kwargs
+
+
+class TraceQueueManager:
+    def __init__(self, *a, **k):
+        pass
+
+    def add_trace_task(self, *a, **k):
+        pass
+
+
+m2.TraceTask = TraceTask
+m2.TraceQueueManager = TraceQueueManager
+sys.modules["core.ops.ops_trace_manager"] = m2
+
+# Stub core.ops.utils to avoid importing heavy dependencies (db, models) during tests
+m_ops = types.ModuleType("core.ops.utils")
+from contextlib import contextmanager
+
+
+@contextmanager
+def measure_time():
+    class Timer:
+        pass
+
+    t = Timer()
+    yield t
+
+
+m_ops.measure_time = measure_time
+sys.modules["core.ops.utils"] = m_ops
+
+m3 = types.ModuleType("core.model_runtime.entities.llm_entities")
+
+
+class LLMUsage:
+    @classmethod
+    def empty_usage(cls):
+        return cls()
+
+
+class LLMResult:
+    def __init__(self, model=None, prompt_messages=None, message=None, usage=None):
+        self.model = model
+        self.prompt_messages = prompt_messages
+        self.message = message
+        self.usage = usage
+
+
+m3.LLMUsage = LLMUsage
+m3.LLMResult = LLMResult
+sys.modules["core.model_runtime.entities.llm_entities"] = m3
+
+m4 = types.ModuleType("core.model_runtime.entities.message_entities")
+
+
+class PromptMessage:
+    def __init__(self, content=None):
+        self.content = content
+
+    def get_text_content(self):
+        return str(self.content) if self.content is not None else ""
+
+
+class TextPromptMessageContent:
+    def __init__(self, data):
+        self.data = data
+
+
+class ImagePromptMessageContent:
+    def __init__(self, url=None, base64_data=None, mime_type=None, filename=None):
+        self.url = url
self.base64_data = base64_data + self.mime_type = mime_type + self.filename = filename + + +class DocumentPromptMessageContent: + def __init__(self, url=None): + self.url = url + + +class AudioPromptMessageContent(DocumentPromptMessageContent): + pass + + +class VideoPromptMessageContent(DocumentPromptMessageContent): + pass + + +class AssistantPromptMessage(PromptMessage): + def __init__(self, content): + super().__init__(content) + + +class UserPromptMessage(PromptMessage): + def __init__(self, content): + super().__init__(content) + + +class SystemPromptMessage(PromptMessage): + def __init__(self, content=None): + super().__init__(content) + + +m4.PromptMessage = PromptMessage +m4.AssistantPromptMessage = AssistantPromptMessage +m4.UserPromptMessage = UserPromptMessage +m4.SystemPromptMessage = SystemPromptMessage +m4.TextPromptMessageContent = TextPromptMessageContent +m4.ImagePromptMessageContent = ImagePromptMessageContent +m4.DocumentPromptMessageContent = DocumentPromptMessageContent +m4.AudioPromptMessageContent = AudioPromptMessageContent +m4.VideoPromptMessageContent = VideoPromptMessageContent +sys.modules["core.model_runtime.entities.message_entities"] = m4 + +m5 = types.ModuleType("core.model_runtime.entities.model_entities") + + +class ModelType: + LLM = None + + +m5.ModelType = ModelType +sys.modules["core.model_runtime.entities.model_entities"] = m5 + +# Stub minimal 'extensions' and 'models' packages to avoid importing heavy application code during tests +ext_db = types.ModuleType("extensions.ext_database") +ext_db.db = None +sys.modules["extensions.ext_database"] = ext_db +ext_storage = types.ModuleType("extensions.ext_storage") +ext_storage.storage = None +sys.modules["extensions.ext_storage"] = ext_storage + +models_m = types.ModuleType("models") + + +class App: + pass + + +class Message: + pass + + +class WorkflowNodeExecutionModel: + pass + + +models_m.App = App +models_m.Message = Message +models_m.WorkflowNodeExecutionModel = WorkflowNodeExecutionModel +sys.modules["models"] = models_m + +models_workflow = types.ModuleType("models.workflow") + + +class Workflow: + pass + + +models_workflow.Workflow = Workflow +sys.modules["models.workflow"] = models_workflow + +from core.llm_generator.llm_generator import LLMGenerator +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import AssistantPromptMessage +from core.model_manager import ModelManager + + +class DummyModelInstance: + def __init__(self, content): + self._content = content + + def invoke_llm(self, prompt_messages=None, model_parameters=None, stream=False): + # Return an LLMResult-like object with the message content we expect + return LLMResult( + model="dummy", + prompt_messages=[], + message=AssistantPromptMessage(content=self._content), + usage=LLMUsage.empty_usage(), + ) + + +def test_generate_conversation_name_persian(monkeypatch): + # Arrange: Persian input that doesn't necessarily include Persian-specific letters + query = "سلام دوست من، میخواهم درباره تنظیمات حساب صحبت کنم" + + # Mock the default model instance to return a Persian title in JSON format + fake_output = json.dumps({"Your Output": "عنوان تستی"}) + dummy = DummyModelInstance(fake_output) + + monkeypatch.setattr(ModelManager, "get_default_model_instance", lambda self, tenant_id, model_type: dummy) + + # Act + name = LLMGenerator.generate_conversation_name("tenant1", query) + + # Assert: title should be the Persian string we returned + assert name == "عنوان تستی" + + 
+
+def test_contains_persian_character_and_heuristics():
+    from core.llm_generator.llm_generator import _contains_persian
+
+    # By single Persian-specific character
+    assert _contains_persian("این یک تست پ") is True
+
+    # By heuristic Persian word
+    assert _contains_persian("سلام دوست") is True
+
+
+def test_contains_persian_langdetect_fallback(monkeypatch):
+    import core.llm_generator.llm_generator as lg
+
+    # Simulate langdetect being available and detecting Persian
+    monkeypatch.setattr(lg, "_langdetect_available", True)
+    monkeypatch.setattr(lg, "detect", lambda text: "fa")
+
+    assert lg._contains_persian("short ambiguous text") is True
+
+    # No manual reset is needed: the monkeypatch fixture automatically restores
+    # _langdetect_available and detect at teardown, so explicit cleanup calls
+    # here would be redundant.
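
A note on the detection layering in api/core/llm_generator/llm_generator.py: the three checks in _contains_persian run cheapest-first, so most Persian inputs are settled by a regex before langdetect is ever consulted. The standalone sketch below mirrors that logic under hypothetical top-level names (contains_persian, PERSIAN_CHARS, PERSIAN_WORDS); it is an illustration of the approach, not the module's exact code.

    import re

    PERSIAN_CHARS = re.compile(r"[پچژگک\u06CC]")  # letters absent from standard Arabic
    PERSIAN_WORDS = re.compile(r"\b(سلام|متشکرم|ممنون|خوب|چطور|سپاس)\b")

    def contains_persian(text: str) -> bool:
        text = text or ""
        if PERSIAN_CHARS.search(text):  # 1) Persian-specific letters
            return True
        if PERSIAN_WORDS.search(text):  # 2) common Persian words
            return True
        try:
            from langdetect import detect  # 3) statistical fallback, if installed
            return detect(text) == "fa"
        except Exception:  # ImportError, or detection failure on short text
            return False

    print(contains_persian("سلام دوست"))   # True, via the word heuristic
    print(contains_persian("hello world"))  # False

The character check is the decisive layer for distinguishing mixed Arabic/Persian text, since پ، چ، ژ، گ and the Persian ye (U+06CC) do not occur in standard Arabic.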
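On why the new langdetect dependency is seeded at import time: the detector is probabilistic, and borderline inputs can flip between "fa" and "ar" across runs unless the factory seed is fixed. The patch sets DetectorFactory.seed = 0 once when the module loads, which makes detection repeatable. A minimal check, assuming langdetect~=1.0.9 is installed:

    from langdetect import DetectorFactory, detect

    DetectorFactory.seed = 0  # fixed seed => deterministic results across runs
    print(detect("این یک جمله فارسی است"))  # expected: "fa"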
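The brace-extraction-plus-repair pattern that appears both in _extract_and_parse_json and in the structured-output fallback reduces to the sketch below. It assumes the json_repair package, whose loads() repairs and parses in one step; the helper name parse_json_object is hypothetical.

    import json

    import json_repair  # repairs malformed JSON commonly emitted by LLMs

    def parse_json_object(raw: str) -> dict | None:
        # Slice from the first "{" to the last "}" to drop chatty preambles/epilogues
        start, end = raw.find("{"), raw.rfind("}")
        candidate = raw[start : end + 1] if 0 <= start < end else raw
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            try:
                parsed = json_repair.loads(candidate)
            except Exception:
                return None
        return parsed if isinstance(parsed, dict) else None

    print(parse_json_object('Sure! {"Your Output": "عنوان"} hope that helps'))

Accepting only dict results keeps callers from having to re-validate the shape of the parsed payload.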