mirror of https://github.com/langgenius/dify.git
Merge 97f750fa22 into 1e86535c4a
commit b11bbdeed8
@@ -275,14 +275,12 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
     start_node_id: str | None = None

 from core.ops.ops_trace_manager import TraceQueueManager

-AppGenerateEntity.model_rebuild()
-EasyUIBasedAppGenerateEntity.model_rebuild()
-ConversationAppGenerateEntity.model_rebuild()
-ChatAppGenerateEntity.model_rebuild()
-CompletionAppGenerateEntity.model_rebuild()
-AgentChatAppGenerateEntity.model_rebuild()
-AdvancedChatAppGenerateEntity.model_rebuild()
-WorkflowAppGenerateEntity.model_rebuild()
-RagPipelineGenerateEntity.model_rebuild()
+AppGenerateEntity.model_rebuild(raise_errors=False)
+EasyUIBasedAppGenerateEntity.model_rebuild(raise_errors=False)
+ConversationAppGenerateEntity.model_rebuild(raise_errors=False)
+ChatAppGenerateEntity.model_rebuild(raise_errors=False)
+CompletionAppGenerateEntity.model_rebuild(raise_errors=False)
+AgentChatAppGenerateEntity.model_rebuild(raise_errors=False)
+AdvancedChatAppGenerateEntity.model_rebuild(raise_errors=False)
+WorkflowAppGenerateEntity.model_rebuild(raise_errors=False)
+RagPipelineGenerateEntity.model_rebuild(raise_errors=False)
@@ -26,7 +26,6 @@ from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
@@ -37,6 +36,55 @@ from models.workflow import Workflow

 logger = logging.getLogger(__name__)

+# Precompiled heuristic to detect common Persian (Farsi) words in short inputs.
+# Using a compiled regex avoids repeated recompilation on every call.
+_PERSIAN_HEURISTIC = re.compile(
+    r"\b(سلام|متشکرم|ممنون|خوب|چطور|سپاس)\b",
+    flags=re.IGNORECASE,
+)
+
+# Precompiled regex for Persian-specific characters (including Persian ye U+06CC)
+_persian_chars_re = re.compile(r"[پچژگک\u06CC]")
+
+# Optional langdetect import — import once at module import time to avoid repeated lookups
+_langdetect_available = False
+try:
+    from langdetect import DetectorFactory, detect  # type: ignore
+
+    DetectorFactory.seed = 0
+    _langdetect_available = True
+except Exception:
+    detect = None
+    DetectorFactory = None
+    _langdetect_available = False
+
+
+def _contains_persian(text: str) -> bool:
+    """Return True if text appears to be Persian (Farsi).
+
+    Detection is multi-layered: a quick character check, word heuristics, and
+    an optional langdetect fallback when available.
+    """
+    text = text or ""
+
+    # 1) Quick check: Persian-specific letters
+    if _persian_chars_re.search(text):
+        return True
+
+    # 2) Heuristic check for common Persian words (fast, precompiled)
+    if _PERSIAN_HEURISTIC.search(text):
+        return True
+
+    # 3) Fallback: language detection (more expensive) — only run if langdetect is available
+    if _langdetect_available and detect is not None:
+        try:
+            return detect(text) == "fa"
+        except Exception as exc:
+            # langdetect can fail for very short/ambiguous texts; log and continue
+            logger.debug("langdetect detection failed: %s", exc)
+
+    return False
+
+
 class WorkflowServiceInterface(Protocol):
     def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
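A minimal usage sketch of the layered detector above (assuming the import path used by the tests later in this diff; the sketch is illustrative, not part of the change):

# Layer 1 fires on a Persian-specific letter, layer 2 on a common Persian word,
# layer 3 (langdetect) only runs when neither quick check matches.
from core.llm_generator.llm_generator import _contains_persian

assert _contains_persian("پردازش")                 # Persian-specific letter 'پ'
assert _contains_persian("سلام دوست من")           # heuristic word match on 'سلام'
assert _contains_persian("hello world") is False   # no Persian signal; langdetect (if present) returns 'en'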
@@ -53,6 +101,8 @@ class LLMGenerator:
     ):
         prompt = CONVERSATION_TITLE_PROMPT

+        # _contains_persian is implemented at module scope for reuse and testability
+
         if len(query) > 2000:
             query = query[:300] + "...[TRUNCATED]..." + query[-300:]
@@ -65,35 +115,155 @@
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
         )

+        # If the input contains Persian characters, add an explicit instruction to produce a Persian title
+        is_persian_input = _contains_persian(query)
+
+        if is_persian_input:
+            prompt += (
+                "\nIMPORTANT: The user input is Persian (Farsi). "
+                "Only output the final title in Persian (Farsi), use Persian characters "
+                "(پ, چ, ژ, گ, ک, ی) and do NOT use Arabic or any other language.\n"
+            )
+
         prompts = [UserPromptMessage(content=prompt)]

-        with measure_time() as timer:
-            response: LLMResult = model_instance.invoke_llm(
-                prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
-            )
-        answer = cast(str, response.message.content)
-        if answer is None:
-            return ""
-        try:
-            result_dict = json.loads(answer)
-        except json.JSONDecodeError:
-            result_dict = json_repair.loads(answer)
+        # Try generation with up to 2 attempts.
+        # If a Persian title is required but not produced, retry with a stronger instruction.
+        attempts = 0
+        max_attempts = 2
+        generated_output = None

-        if not isinstance(result_dict, dict):
-            answer = query
-        else:
-            output = result_dict.get("Your Output")
-            if isinstance(output, str) and output.strip():
-                answer = output.strip()
+        while attempts < max_attempts:
+            attempts += 1
+            try:
+                response: LLMResult = model_instance.invoke_llm(
+                    prompt_messages=list(prompts),
+                    model_parameters={"max_tokens": 500, "temperature": 0.2},
+                    stream=False,
+                )
+            except (InvokeError, InvokeAuthorizationError):
+                logger.exception("Failed to invoke LLM for conversation name generation")
+                break
+
+            answer = cast(str, response.message.content)
+
+            def _extract_and_parse_json(raw_text: str) -> dict | None:
+                if not raw_text:
+                    return None
+                # Try to extract the JSON object between the outermost braces
+                first_brace = raw_text.find("{")
+                last_brace = raw_text.rfind("}")
+                if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
+                    candidate_json = raw_text[first_brace : last_brace + 1]
+                else:
+                    candidate_json = raw_text
+
+                # Try a plain json.loads first, then attempt to repair malformed JSON
+                try:
+                    parsed = json.loads(candidate_json)
+                    # Only accept dict results for structured conversation title parsing
+                    return parsed if isinstance(parsed, dict) else None
+                except json.JSONDecodeError:
+                    # Prefer a json_repair.loads implementation if available
+                    json_repair_loads = getattr(json_repair, "loads", None)
+                    if callable(json_repair_loads):
+                        try:
+                            repaired_parsed = json_repair_loads(candidate_json)
+                            if isinstance(repaired_parsed, dict):
+                                return repaired_parsed
+                            # If the repair function returns a string, try parsing it
+                            if isinstance(repaired_parsed, str):
+                                try:
+                                    parsed2 = json.loads(repaired_parsed)
+                                    return parsed2 if isinstance(parsed2, dict) else None
+                                except Exception:
+                                    return None
+                            return None
+                        except Exception as exc:
+                            logger.debug("json_repair.loads failed: %s", exc)
+                            return None
+
+                    # Otherwise call a 'repair' function if present and parse its result
+                    json_repair_repair = getattr(json_repair, "repair", None)
+                    if callable(json_repair_repair):
+                        try:
+                            repaired = json_repair_repair(candidate_json)
+                            if isinstance(repaired, dict):
+                                return repaired
+                            if isinstance(repaired, str):
+                                parsed = json.loads(repaired)
+                                return parsed if isinstance(parsed, dict) else None
+                            return None
+                        except Exception as exc:
+                            logger.debug("json_repair.repair failed: %s", exc)
+                            return None
+
+                    logger.debug("No suitable json_repair function available to repair JSON")
+                    return None
+
+            result_dict = _extract_and_parse_json(answer)
+
+            if not isinstance(result_dict, dict):
+                candidate = query
+            else:
+                candidate = result_dict.get("Your Output", "")
+                if not isinstance(candidate, str):
+                    # Guard against non-string JSON values before calling .strip() below
+                    candidate = query
+
+            # If the input is Persian, ensure the candidate contains Persian-specific characters.
+            # Otherwise retry with a stronger instruction.
+            if is_persian_input and not _contains_persian(candidate):
+                logger.info("Generated title doesn't appear to be Persian; retrying with a stricter instruction")
+                prompts = [
+                    UserPromptMessage(
+                        content=(
+                            prompt + "\nCRITICAL: You must output the title in Persian (Farsi) "
+                            "using Persian-specific letters (پ, چ, ژ, گ, ک, ی). "
+                            "Output only the JSON as specified earlier."
+                        )
+                    )
+                ]
+                continue
+
+            generated_output = candidate.strip()
+            break
+
+        if generated_output:
+            name = generated_output
+        else:
-            answer = query
+            # Use the last non-Persian candidate (if any) so that the translation fallback
+            # can translate the generated candidate into Persian. Otherwise fall back to
+            # the original query.
+            last_candidate = locals().get("candidate", None)
+            name = last_candidate.strip() if isinstance(last_candidate, str) and last_candidate else (query or "")

-        name = answer.strip()
+        if is_persian_input and not _contains_persian(name):
+            # As a last resort, ask the model to translate the title into Persian directly
+            try:
+                translate_prompt = UserPromptMessage(
+                    content=(
+                        "Translate the following short chat title into Persian (Farsi) ONLY. "
+                        "Output the Persian translation only (no JSON):\n\n"
+                        f"{name}"
+                    )
+                )
+                translate_response: LLMResult = model_instance.invoke_llm(
+                    prompt_messages=[translate_prompt],
+                    model_parameters={"max_tokens": 200, "temperature": 0},
+                    stream=False,
+                )
+                translation = cast(str, translate_response.message.content).strip()
+                if _contains_persian(translation):
+                    name = translation
+            except (InvokeError, InvokeAuthorizationError):
+                logger.exception("Failed to obtain a Persian translation for the conversation title")
+
         if len(name) > 75:
             name = name[:75] + "..."

         # get tracing instance
         from core.ops.ops_trace_manager import TraceQueueManager, TraceTask

         trace_manager = TraceQueueManager(app_id=app_id)
         trace_manager.add_trace_task(
             TraceTask(
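The brace-extraction step in the loop above can be exercised in isolation; a simplified sketch (extract_json_object is a hypothetical name, and the json_repair fallback is omitted):

import json

def extract_json_object(raw_text: str) -> dict | None:
    # Take the outermost {...} span, as the loop above does, and accept only dicts
    first, last = raw_text.find("{"), raw_text.rfind("}")
    candidate = raw_text[first : last + 1] if first != -1 and last > first else raw_text
    try:
        parsed = json.loads(candidate)
        return parsed if isinstance(parsed, dict) else None
    except json.JSONDecodeError:
        return None  # the shipped helper additionally probes json_repair here

print(extract_json_object('Sure! {"Your Output": "عنوان"} hope that helps'))
# {'Your Output': 'عنوان'}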
@@ -380,10 +550,46 @@
         if not isinstance(raw_content, str):
             raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")

+        # Initialize parsed_content so the variable is always bound for type-checkers
+        parsed_content: dict | list | None = None
         try:
             parsed_content = json.loads(raw_content)
         except json.JSONDecodeError:
-            parsed_content = json_repair.loads(raw_content)
+            # Prefer a json_repair.loads implementation if available
+            json_repair_loads = getattr(json_repair, "loads", None)
+            if callable(json_repair_loads):
+                try:
+                    parsed_candidate = json_repair_loads(raw_content)
+                    # Accept a dict or list directly
+                    if isinstance(parsed_candidate, (dict, list)):
+                        parsed_content = parsed_candidate
+                    elif isinstance(parsed_candidate, str):
+                        try:
+                            parsed2 = json.loads(parsed_candidate)
+                            parsed_content = parsed2 if isinstance(parsed2, (dict, list)) else None
+                        except Exception as exc:
+                            logger.debug("json_repair.loads returned a string that failed to parse: %s", exc)
+                            parsed_content = None
+                    else:
+                        parsed_content = None
+                except Exception as exc:
+                    logger.debug("json_repair.loads failed: %s", exc)
+                    parsed_content = None
+            else:
+                # As a fallback, use a 'repair' function followed by json.loads
+                json_repair_repair = getattr(json_repair, "repair", None)
+                if callable(json_repair_repair):
+                    try:
+                        repaired = json_repair_repair(raw_content)
+                        if isinstance(repaired, (dict, list)):
+                            parsed_content = repaired
+                        elif isinstance(repaired, str):
+                            parsed_content = json.loads(repaired)
+                        else:
+                            parsed_content = None
+                    except Exception as exc:
+                        logger.debug("json_repair.repair failed: %s", exc)
+                        parsed_content = None

         if not isinstance(parsed_content, dict | list):
             raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
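The same defensive dispatch, condensed (a sketch; json_repair.loads exists in current releases of the json-repair package, while the 'repair' entry point is probed speculatively, exactly as the diff does, in case of older or alternative builds):

import json
import json_repair

def tolerant_loads(raw: str):
    # Prefer json_repair.loads; fall back to repair() plus json.loads if only that exists
    loads_fn = getattr(json_repair, "loads", None)
    if callable(loads_fn):
        return loads_fn(raw)
    repair_fn = getattr(json_repair, "repair", None)
    if callable(repair_fn):
        return json.loads(repair_fn(raw))
    raise ValueError("no usable json_repair entry point")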
@@ -561,16 +767,11 @@
                 prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             )

-            generated_raw = response.message.get_text_content()
+            generated_raw = cast(str, response.message.content)
             first_brace = generated_raw.find("{")
             last_brace = generated_raw.rfind("}")
             if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
                 raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
-            json_str = generated_raw[first_brace : last_brace + 1]
-            data = json_repair.loads(json_str)
-            if not isinstance(data, dict):
-                raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
-            return data
+            return {**json.loads(generated_raw[first_brace : last_brace + 1])}

         except InvokeError as e:
             error = str(e)
             return {"error": f"Failed to generate code. Error: {error}"}
@@ -11,6 +11,8 @@ Automatically identify the language of the user’s input (e.g. English, Chinese
 - The title must be natural, friendly, and in the same language as the input.
 - If the input is a direct question to the model, you may add an emoji at the end.

+- Special Note for Persian (Farsi): If the input is Persian (Farsi), ALWAYS generate the title in Persian (Farsi). Prefer using distinctly Persian characters (for example: پ، چ، ژ، گ). You may also use ک and ی, but prefer the Persian form (e.g., U+06CC for "ye"). Ensure the "Language Type" field is "Persian" or "Farsi". Do NOT use Arabic or any other language or script when the input is Persian.
+
 3. Output Format
 Return **only** a valid JSON object with these exact keys and no additional text:
 {
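For reference, a response matching that contract (keys mirrored from make_json_response in the tests later in this diff) would look like:

import json

title_payload = {
    "Language Type": "Persian",
    "Your Reasoning": "...",
    "Your Output": "عنوان گفتگو",
}
print(json.dumps(title_payload, ensure_ascii=False))
# {"Language Type": "Persian", "Your Reasoning": "...", "Your Output": "عنوان گفتگو"}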
@@ -35,6 +35,7 @@ dependencies = [
     "jsonschema>=4.25.1",
     "langfuse~=2.51.3",
     "langsmith~=0.1.77",
+    "langdetect~=1.0.9",
     "markdown~=3.5.1",
     "mlflow-skinny>=3.0.0",
     "numpy~=1.26.4",
@@ -0,0 +1,110 @@
+import json
+from unittest.mock import MagicMock, patch
+
+from core.llm_generator.llm_generator import LLMGenerator
+
+
+class DummyMessage:
+    def __init__(self, content):
+        self.content = content
+
+
+class DummyResponse:
+    def __init__(self, content):
+        self.message = DummyMessage(content)
+
+
+def make_json_response(language, output):
+    return json.dumps({"Language Type": language, "Your Reasoning": "...", "Your Output": output})
+
+
+@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
+def test_generate_conversation_name_enforces_persian(mock_get_model):
+    # A Persian input containing the Persian-specific character 'پ'
+    persian_query = "سلام، چطوری؟ پ"  # contains 'پ'
+
+    # First model response: misdetected as Arabic, returning an Arabic title
+    first_resp = DummyResponse(make_json_response("Arabic", "مرحبا"))
+    # Second response (after retry): a Persian title with Persian-specific chars
+    second_resp = DummyResponse(make_json_response("Persian", "عنوان پِرس"))
+
+    model_instance = MagicMock()
+    model_instance.invoke_llm.side_effect = [first_resp, second_resp]
+
+    mock_get_model.return_value = model_instance
+
+    name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
+
+    # The final name should come from the Persian response (contains the Persian-specific char 'پ')
+    assert "پ" in name
+    # Ensure the model was invoked at least twice (a retry occurred)
+    assert model_instance.invoke_llm.call_count >= 2
+
+
+@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
+def test_generate_conversation_name_translation_fallback(mock_get_model):
+    # Persian query
+    persian_query = "این یک تست است پ"
+
+    # The model returns non-Persian outputs consistently
+    non_persian_resp = DummyResponse(make_json_response("Arabic", "مرحبا"))
+
+    # The translation response (last call) returns a Persian translation
+    translate_resp = DummyResponse("عنوان ترجمه شده پ")
+
+    model_instance = MagicMock()
+    # The first two calls return non-Persian results; the third call is the translation
+    model_instance.invoke_llm.side_effect = [non_persian_resp, non_persian_resp, translate_resp]
+
+    mock_get_model.return_value = model_instance
+
+    name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
+
+    # The final name should contain the Persian character 'پ' from the translation fallback
+    assert "پ" in name
+    assert model_instance.invoke_llm.call_count >= 3
+
+
+@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
+def test_generate_conversation_name_enforces_persian_retry_prompt(mock_get_model):
+    # A Persian input containing the Persian-specific character 'پ'
+    persian_query = "سلام، چطوری؟ پ"
+
+    # First model response: misdetected as Arabic, returning an Arabic title
+    first_resp = DummyResponse(make_json_response("Arabic", "مرحبا"))
+    # Second response (after retry): a Persian title with Persian-specific chars
+    second_resp = DummyResponse(make_json_response("Persian", "عنوان پِرس"))
+
+    model_instance = MagicMock()
+    model_instance.invoke_llm.side_effect = [first_resp, second_resp]
+
+    mock_get_model.return_value = model_instance
+
+    name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
+
+    # The final name should come from the Persian response (contains the Persian-specific char 'پ')
+    assert "پ" in name
+
+    # Ensure the retry prompt included the stronger Persian-only instruction
+    assert model_instance.invoke_llm.call_count >= 2
+    second_call_kwargs = model_instance.invoke_llm.call_args_list[1][1]
+    prompt_msg = second_call_kwargs["prompt_messages"][0]
+    assert "CRITICAL: You must output the title in Persian" in prompt_msg.content
+
+
+@patch("core.llm_generator.llm_generator.ModelManager.get_default_model_instance")
+def test_generate_conversation_name_handles_invoke_error(mock_get_model):
+    # If LLM invocation raises InvokeError, ensure the fallback/translation is attempted and no exception bubbles up
+    from core.model_runtime.errors.invoke import InvokeError
+
+    persian_query = "سلام، پ"
+
+    model_instance = MagicMock()
+    # The first invocation raises InvokeError; the translation attempt returns a Persian translation
+    model_instance.invoke_llm.side_effect = [InvokeError("boom"), DummyResponse("عنوان ترجمه شده پ")]
+
+    mock_get_model.return_value = model_instance
+
+    name = LLMGenerator.generate_conversation_name("tenant1", persian_query)
+
+    assert "پ" in name
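These tests lean on MagicMock's side_effect to script one response per invoke_llm call; in miniature:

from unittest.mock import MagicMock

m = MagicMock()
m.invoke_llm.side_effect = ["first", "second"]
assert m.invoke_llm() == "first"   # initial generation attempt
assert m.invoke_llm() == "second"  # retry (or translation fallback) attempt
assert m.invoke_llm.call_count == 2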
@@ -1404,6 +1404,7 @@ dependencies = [
     { name = "jieba" },
     { name = "json-repair" },
     { name = "jsonschema" },
+    { name = "langdetect" },
     { name = "langfuse" },
     { name = "langsmith" },
     { name = "litellm" },
@@ -1602,6 +1603,7 @@ requires-dist = [
     { name = "jieba", specifier = "==0.42.1" },
     { name = "json-repair", specifier = ">=0.41.1" },
     { name = "jsonschema", specifier = ">=4.25.1" },
+    { name = "langdetect", specifier = "~=1.0.9" },
     { name = "langfuse", specifier = "~=2.51.3" },
     { name = "langsmith", specifier = "~=0.1.77" },
     { name = "litellm", specifier = "==1.77.1" },
@@ -176,6 +176,12 @@ services:
       THIRD_PARTY_SIGNATURE_VERIFICATION_ENABLED: true
       THIRD_PARTY_SIGNATURE_VERIFICATION_PUBLIC_KEYS: /app/keys/publickey.pem
       FORCE_VERIFYING_SIGNATURE: false
+
+      HTTP_PROXY: ${HTTP_PROXY:-http://ssrf_proxy:3128}
+      HTTPS_PROXY: ${HTTPS_PROXY:-http://ssrf_proxy:3128}
       PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
     extra_hosts:
       - "host.docker.internal:host-gateway"
     ports:
       - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}"
       - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
@@ -0,0 +1,256 @@
+import json
+import sys
+import types
+from contextlib import contextmanager
+from pathlib import Path
+
+# Ensure the repo `api/` directory is importable so tests can import `core.*` without external env setup
+ROOT = Path(__file__).resolve().parents[3]
+sys.path.insert(0, str(ROOT / "api"))
+
+# Lightweight stubs to avoid importing heavy application modules during unit tests
+m = types.ModuleType("core.model_manager")
+
+
+class ModelManager:
+    def get_default_model_instance(self, tenant_id, model_type):
+        raise NotImplementedError
+
+    def get_model_instance(self, tenant_id, model_type, provider=None, model=None):
+        raise NotImplementedError
+
+
+m.ModelManager = ModelManager
+sys.modules["core.model_manager"] = m
+
+m2 = types.ModuleType("core.ops.ops_trace_manager")
+
+
+class TraceTask:
+    def __init__(self, *args, **kwargs):
+        # store attributes for potential inspection in tests
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        self.args = args
+        self.kwargs = kwargs
+
+
+class TraceQueueManager:
+    def __init__(self, *a, **k):
+        pass
+
+    def add_trace_task(self, *a, **k):
+        pass
+
+
+m2.TraceTask = TraceTask
+m2.TraceQueueManager = TraceQueueManager
+sys.modules["core.ops.ops_trace_manager"] = m2
+
+# Stub core.ops.utils to avoid importing heavy dependencies (db, models) during tests
+m_ops = types.ModuleType("core.ops.utils")
+
+
+@contextmanager
+def measure_time():
+    class Timer:
+        pass
+
+    t = Timer()
+    yield t
+
+
+m_ops.measure_time = measure_time
+sys.modules["core.ops.utils"] = m_ops
+
+m3 = types.ModuleType("core.model_runtime.entities.llm_entities")
+
+
+class LLMUsage:
+    @classmethod
+    def empty_usage(cls):
+        return cls()
+
+
+class LLMResult:
+    def __init__(self, model=None, prompt_messages=None, message=None, usage=None):
+        self.model = model
+        self.prompt_messages = prompt_messages
+        self.message = message
+        self.usage = usage
+
+
+m3.LLMUsage = LLMUsage
+m3.LLMResult = LLMResult
+sys.modules["core.model_runtime.entities.llm_entities"] = m3
+
+m4 = types.ModuleType("core.model_runtime.entities.message_entities")
+
+
+class PromptMessage:
+    def __init__(self, content=None):
+        self.content = content
+
+    def get_text_content(self):
+        return str(self.content) if self.content is not None else ""
+
+
+class TextPromptMessageContent:
+    def __init__(self, data):
+        self.data = data
+
+
+class ImagePromptMessageContent:
+    def __init__(self, url=None, base64_data=None, mime_type=None, filename=None):
+        self.url = url
+        self.base64_data = base64_data
+        self.mime_type = mime_type
+        self.filename = filename
+
+
+class DocumentPromptMessageContent:
+    def __init__(self, url=None):
+        self.url = url
+
+
+class AudioPromptMessageContent(DocumentPromptMessageContent):
+    pass
+
+
+class VideoPromptMessageContent(DocumentPromptMessageContent):
+    pass
+
+
+class AssistantPromptMessage(PromptMessage):
+    def __init__(self, content):
+        super().__init__(content)
+
+
+class UserPromptMessage(PromptMessage):
+    def __init__(self, content):
+        super().__init__(content)
+
+
+class SystemPromptMessage(PromptMessage):
+    def __init__(self, content=None):
+        super().__init__(content)
+
+
+m4.PromptMessage = PromptMessage
+m4.AssistantPromptMessage = AssistantPromptMessage
+m4.UserPromptMessage = UserPromptMessage
+m4.SystemPromptMessage = SystemPromptMessage
+m4.TextPromptMessageContent = TextPromptMessageContent
+m4.ImagePromptMessageContent = ImagePromptMessageContent
+m4.DocumentPromptMessageContent = DocumentPromptMessageContent
+m4.AudioPromptMessageContent = AudioPromptMessageContent
+m4.VideoPromptMessageContent = VideoPromptMessageContent
+sys.modules["core.model_runtime.entities.message_entities"] = m4
+
+m5 = types.ModuleType("core.model_runtime.entities.model_entities")
+
+
+class ModelType:
+    LLM = None
+
+
+m5.ModelType = ModelType
+sys.modules["core.model_runtime.entities.model_entities"] = m5
+
+# Stub minimal 'extensions' and 'models' packages to avoid importing heavy application code during tests
+ext_db = types.ModuleType("extensions.ext_database")
+ext_db.db = None
+sys.modules["extensions.ext_database"] = ext_db
+ext_storage = types.ModuleType("extensions.ext_storage")
+ext_storage.storage = None
+sys.modules["extensions.ext_storage"] = ext_storage
+
+models_m = types.ModuleType("models")
+
+
+class App:
+    pass
+
+
+class Message:
+    pass
+
+
+class WorkflowNodeExecutionModel:
+    pass
+
+
+models_m.App = App
+models_m.Message = Message
+models_m.WorkflowNodeExecutionModel = WorkflowNodeExecutionModel
+sys.modules["models"] = models_m
+
+models_workflow = types.ModuleType("models.workflow")
+
+
+class Workflow:
+    pass
+
+
+models_workflow.Workflow = Workflow
+sys.modules["models.workflow"] = models_workflow
+
+from core.llm_generator.llm_generator import LLMGenerator
+from core.model_manager import ModelManager
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import AssistantPromptMessage
+
+
+class DummyModelInstance:
+    def __init__(self, content):
+        self._content = content
+
+    def invoke_llm(self, prompt_messages=None, model_parameters=None, stream=False):
+        # Return an LLMResult-like object with the message content we expect
+        return LLMResult(
+            model="dummy",
+            prompt_messages=[],
+            message=AssistantPromptMessage(content=self._content),
+            usage=LLMUsage.empty_usage(),
+        )
+
+
+def test_generate_conversation_name_persian(monkeypatch):
+    # Arrange: Persian input that doesn't necessarily include Persian-specific letters
+    query = "سلام دوست من، میخواهم درباره تنظیمات حساب صحبت کنم"
+
+    # Mock the default model instance to return a Persian title in JSON format
+    fake_output = json.dumps({"Your Output": "عنوان تستی"})
+    dummy = DummyModelInstance(fake_output)
+
+    monkeypatch.setattr(ModelManager, "get_default_model_instance", lambda self, tenant_id, model_type: dummy)
+
+    # Act
+    name = LLMGenerator.generate_conversation_name("tenant1", query)
+
+    # Assert: the title should be the Persian string we returned
+    assert name == "عنوان تستی"
+
+
+def test_contains_persian_character_and_heuristics():
+    from core.llm_generator.llm_generator import _contains_persian
+
+    # By a single Persian-specific character
+    assert _contains_persian("این یک تست پ") is True
+
+    # By a heuristic Persian word
+    assert _contains_persian("سلام دوست") is True
+
+
+def test_contains_persian_langdetect_fallback(monkeypatch):
+    import core.llm_generator.llm_generator as lg
+
+    # Simulate langdetect being available and detecting Persian
+    monkeypatch.setattr(lg, "_langdetect_available", True)
+    monkeypatch.setattr(lg, "detect", lambda text: "fa")
+
+    assert lg._contains_persian("short ambiguous text") is True
+
+    # monkeypatch undoes these at teardown; reset defensively within the test anyway
+    monkeypatch.setattr(lg, "_langdetect_available", False)
+    monkeypatch.setattr(lg, "detect", None)
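The stubbing technique used throughout this file, reduced to its core (a generic sketch, not tied to this repo; heavy_dependency is a hypothetical module name): register a fake module in sys.modules before the code under test imports it.

import sys
import types

fake = types.ModuleType("heavy_dependency")      # hypothetical module name
fake.expensive_call = lambda *a, **k: "stubbed"
sys.modules["heavy_dependency"] = fake           # later imports resolve to the stub

import heavy_dependency

assert heavy_dependency.expensive_call() == "stubbed"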