From 6facd9360c94244a76d906b0681d464fdb137329 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 12 May 2026 23:45:58 +0900 Subject: [PATCH] chore: some match case (#36080) Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- api/core/prompt/advanced_prompt_transform.py | 62 +++++++++++--------- api/core/rag/extractor/extract_processor.py | 59 ++++++++++--------- 2 files changed, 63 insertions(+), 58 deletions(-) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 24e05ef865..5a9914e6e4 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -161,35 +161,39 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages: list[PromptMessage] = [] for prompt_item in prompt_template: raw_prompt = prompt_item.text - - if prompt_item.edition_type == "basic" or not prompt_item.edition_type: - if self.with_variable_tmpl: - vp = VariablePool.empty() - for k, v in inputs.items(): - if k.startswith("#"): - vp.add(k[1:-1].split("."), v) - raw_prompt = raw_prompt.replace("{{#context#}}", context or "") - prompt = vp.convert_template(raw_prompt).text - else: - parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} - prompt_inputs = self._set_context_variable( - context=context, parser=parser, prompt_inputs=prompt_inputs - ) - prompt = parser.format(prompt_inputs) - elif prompt_item.edition_type == "jinja2": - prompt = raw_prompt - prompt_inputs = inputs - prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs) - else: - raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") - - if prompt_item.role == PromptMessageRole.USER: - prompt_messages.append(UserPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.ASSISTANT: - prompt_messages.append(AssistantPromptMessage(content=prompt)) + edition_type = prompt_item.edition_type or "basic" + match edition_type: + case "basic": + if self.with_variable_tmpl: + vp = VariablePool.empty() + for k, v in inputs.items(): + if k.startswith("#"): + vp.add(k[1:-1].split("."), v) + raw_prompt = raw_prompt.replace("{{#context#}}", context or "") + prompt = vp.convert_template(raw_prompt).text + else: + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs = self._set_context_variable( + context=context, parser=parser, prompt_inputs=prompt_inputs + ) + prompt = parser.format(prompt_inputs) + case "jinja2": + prompt = raw_prompt + prompt_inputs = inputs + prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs) + case _: + raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") + match prompt_item.role: + case PromptMessageRole.USER: + prompt_messages.append(UserPromptMessage(content=prompt)) + case PromptMessageRole.SYSTEM: + if prompt: + prompt_messages.append(SystemPromptMessage(content=prompt)) + case PromptMessageRole.ASSISTANT: + prompt_messages.append(AssistantPromptMessage(content=prompt)) + case PromptMessageRole.TOOL: + pass if query and memory_config and memory_config.query_prompt_template: parser = PromptTemplateParser( diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index b679edab36..e49e814149 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -183,34 +183,35 @@ class ExtractProcessor: return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE: assert extract_setting.website_info is not None, "website_info is required" - if extract_setting.website_info.provider == "firecrawl": - extractor = FirecrawlWebExtractor( - url=extract_setting.website_info.url, - job_id=extract_setting.website_info.job_id, - tenant_id=extract_setting.website_info.tenant_id, - mode=extract_setting.website_info.mode, - only_main_content=extract_setting.website_info.only_main_content, - ) - return extractor.extract() - elif extract_setting.website_info.provider == "watercrawl": - extractor = WaterCrawlWebExtractor( - url=extract_setting.website_info.url, - job_id=extract_setting.website_info.job_id, - tenant_id=extract_setting.website_info.tenant_id, - mode=extract_setting.website_info.mode, - only_main_content=extract_setting.website_info.only_main_content, - ) - return extractor.extract() - elif extract_setting.website_info.provider == "jinareader": - extractor = JinaReaderWebExtractor( - url=extract_setting.website_info.url, - job_id=extract_setting.website_info.job_id, - tenant_id=extract_setting.website_info.tenant_id, - mode=extract_setting.website_info.mode, - only_main_content=extract_setting.website_info.only_main_content, - ) - return extractor.extract() - else: - raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}") + match extract_setting.website_info.provider: + case "firecrawl": + extractor = FirecrawlWebExtractor( + url=extract_setting.website_info.url, + job_id=extract_setting.website_info.job_id, + tenant_id=extract_setting.website_info.tenant_id, + mode=extract_setting.website_info.mode, + only_main_content=extract_setting.website_info.only_main_content, + ) + return extractor.extract() + case "watercrawl": + extractor = WaterCrawlWebExtractor( + url=extract_setting.website_info.url, + job_id=extract_setting.website_info.job_id, + tenant_id=extract_setting.website_info.tenant_id, + mode=extract_setting.website_info.mode, + only_main_content=extract_setting.website_info.only_main_content, + ) + return extractor.extract() + case "jinareader": + extractor = JinaReaderWebExtractor( + url=extract_setting.website_info.url, + job_id=extract_setting.website_info.job_id, + tenant_id=extract_setting.website_info.tenant_id, + mode=extract_setting.website_info.mode, + only_main_content=extract_setting.website_info.only_main_content, + ) + return extractor.extract() + case _: + raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}") else: raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")