From 3216b67bfa93417814dbdf05fd8d0e174df3c627 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 1 Feb 2026 19:25:54 +0900 Subject: [PATCH 01/43] refactor: examples of use match case (#31312) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/annotation.py | 9 +-- api/controllers/console/auth/oauth_server.py | 66 ++++++++++---------- 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 6a4c1528b0..a07145ce9f 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -107,10 +107,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_id, action: Literal["enable", "disable"]): app_id = str(app_id) args = AnnotationReplyPayload.model_validate(console_ns.payload) - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6162d88a0b..38ea5d2dae 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource): grant_type = OAuthGrantType(payload.grant_type) except ValueError: raise BadRequest("invalid grant_type") + match grant_type: + case OAuthGrantType.AUTHORIZATION_CODE: + if not payload.code: + raise BadRequest("code is required") - if grant_type == OAuthGrantType.AUTHORIZATION_CODE: - if not payload.code: - raise BadRequest("code is required") + if payload.client_secret != oauth_provider_app.client_secret: + raise BadRequest("client_secret is invalid") - if payload.client_secret != oauth_provider_app.client_secret: - raise BadRequest("client_secret is invalid") + if payload.redirect_uri not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") - if payload.redirect_uri not in oauth_provider_app.redirect_uris: - raise BadRequest("redirect_uri is invalid") + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, code=payload.code, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + case OAuthGrantType.REFRESH_TOKEN: + if not payload.refresh_token: + raise BadRequest("refresh_token is required") - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, code=payload.code, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) - elif grant_type == OAuthGrantType.REFRESH_TOKEN: - if not payload.refresh_token: - raise BadRequest("refresh_token is required") - - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) @console_ns.route("/oauth/provider/account") From 4f826b4641f44a8e5c1185ee09455dcd5bff4042 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:41:34 +0800 Subject: [PATCH 02/43] refactor(typing): use enum types for workflow status fields (#31792) --- .../app/apps/common/workflow_response_converter.py | 10 +++++----- api/core/app/entities/task_entities.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 38ecec5d30..cefff7be92 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -250,7 +250,7 @@ class WorkflowResponseConverter: data=WorkflowFinishStreamResponse.Data( id=run_id, workflow_id=workflow_id, - status=status.value, + status=status, outputs=encoded_outputs, error=error, elapsed_time=elapsed_time, @@ -340,13 +340,13 @@ class WorkflowResponseConverter: metadata = self._merge_metadata(event.execution_metadata, snapshot) if isinstance(event, QueueNodeSucceededEvent): - status = WorkflowNodeExecutionStatus.SUCCEEDED.value + status = WorkflowNodeExecutionStatus.SUCCEEDED error_message = event.error elif isinstance(event, QueueNodeFailedEvent): - status = WorkflowNodeExecutionStatus.FAILED.value + status = WorkflowNodeExecutionStatus.FAILED error_message = event.error else: - status = WorkflowNodeExecutionStatus.EXCEPTION.value + status = WorkflowNodeExecutionStatus.EXCEPTION error_message = event.error return NodeFinishStreamResponse( @@ -413,7 +413,7 @@ class WorkflowResponseConverter: process_data_truncated=process_data_truncated, outputs=outputs, outputs_truncated=outputs_truncated, - status=WorkflowNodeExecutionStatus.RETRY.value, + status=WorkflowNodeExecutionStatus.RETRY, error=event.error, elapsed_time=elapsed_time, execution_metadata=metadata, diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 79a5e657b3..26fb17ccef 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class AnnotationReplyAccount(BaseModel): @@ -223,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float @@ -311,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = True - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -375,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = False - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -719,7 +719,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float From 41177757e64a174abcdf577079e45936294e19d4 Mon Sep 17 00:00:00 2001 From: FFXN <31929997+FFXN@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:45:17 +0800 Subject: [PATCH 03/43] fix: summary index bug (#31810) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: zxhlyh Co-authored-by: Yansong Zhang <916125788@qq.com> Co-authored-by: hj24 Co-authored-by: CodingOnStar Co-authored-by: CodingOnStar Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../console/datasets/datasets_document.py | 12 ++++++ api/core/indexing_runner.py | 4 +- api/core/llm_generator/prompts.py | 4 +- .../index_processor/index_processor_base.py | 12 +++++- .../processor/paragraph_index_processor.py | 31 ++++++++++++-- .../processor/parent_child_index_processor.py | 8 +++- .../processor/qa_index_processor.py | 6 ++- .../knowledge_index/knowledge_index_node.py | 13 ++++++ api/services/dataset_service.py | 41 +++++++++++++++++++ .../rag_pipeline_transform_service.py | 4 ++ api/services/summary_index_service.py | 11 ++++- 11 files changed, 137 insertions(+), 9 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6e3c0db8a3..6a0c9e5f77 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource): missing_ids = set(document_list) - found_ids raise NotFound(f"Some documents not found: {list(missing_ids)}") + # Update need_summary to True for documents that don't have it set + # This handles the case where documents were created when summary_index_setting was disabled + documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"] + + if documents_to_update: + document_ids_to_update = [str(doc.id) for doc in documents_to_update] + DocumentService.update_documents_need_summary( + dataset_id=dataset_id, + document_ids=document_ids_to_update, + need_summary=True, + ) + # Dispatch async tasks for each document for document in documents: # Skip qa_model documents as they don't generate summaries diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index e172e88298..61f168a26f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -369,7 +369,9 @@ class IndexingRunner: # Generate summary preview summary_index_setting = tmp_processing_rule.get("summary_index_setting") if summary_index_setting and summary_index_setting.get("enable") and preview_texts: - preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting) + preview_texts = index_processor.generate_summary_preview( + tenant_id, preview_texts, summary_index_setting, doc_language + ) return IndexingEstimate(total_segments=total_segments, preview=preview_texts) diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index d46cf049dd..ee9a016c95 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -441,11 +441,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = ( Requirements: 1. Write a concise summary in plain text -2. Use the same language as the input content +2. You must write in {language}. No language other than {language} should be used. 3. Focus on important facts, concepts, and details 4. If images are included, describe their key information 5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions" 6. Write directly without extra words +7. If there is not enough content to generate a meaningful summary, + return an empty string without any explanation or prompt Output only the summary text. Start summarizing now: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 151a3de7d9..6e76321ea0 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC): @abstractmethod def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each segment in preview_texts, generate a summary using LLM and attach it to the segment. The summary can be stored in a new attribute, e.g., summary. This method should be implemented by subclasses. + + Args: + tenant_id: Tenant ID + preview_texts: List of preview details to generate summaries for + summary_index_setting: Summary index configuration + doc_language: Optional document language to ensure summary is generated in the correct language """ raise NotImplementedError diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index ab91e29145..41d7656f8a 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("Chunks is not a list") def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each segment, concurrently call generate_summary to generate a summary @@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if flask_app: # Ensure Flask app context in worker thread with flask_app.app_context(): - summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) preview.summary = summary else: # Fallback: try without app context (may fail) - summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) preview.summary = summary # Generate summaries concurrently using ThreadPoolExecutor @@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): text: str, summary_index_setting: dict | None = None, segment_id: str | None = None, + document_language: str | None = None, ) -> tuple[str, LLMUsage]: """ Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt, @@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor): text: Text content to summarize summary_index_setting: Summary index configuration segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table + document_language: Optional document language (e.g., "Chinese", "English") + to ensure summary is generated in the correct language Returns: Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object @@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("model_name and model_provider_name are required in summary_index_setting") # Import default summary prompt + is_default_prompt = False if not summary_prompt: summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT + is_default_prompt = True + + # Format prompt with document language only for default prompt + # Custom prompts are used as-is to avoid interfering with user-defined templates + # If document_language is provided, use it; otherwise, use "the same language as the input content" + # This is especially important for image-only chunks where text is empty or minimal + if is_default_prompt: + language_for_prompt = document_language or "the same language as the input content" + try: + summary_prompt = summary_prompt.format(language=language_for_prompt) + except KeyError: + # If default prompt doesn't have {language} placeholder, use it as-is + pass provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 961df2e50c..0ea77405ed 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary @@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): tenant_id=tenant_id, text=preview.content, summary_index_setting=summary_index_setting, + document_language=doc_language, ) preview.summary = summary else: @@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): tenant_id=tenant_id, text=preview.content, summary_index_setting=summary_index_setting, + document_language=doc_language, ) preview.summary = summary diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 272d2ed351..40d9caaa69 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor): } def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ QA model doesn't generate summaries, so this method returns preview_texts unchanged. diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index b88c2d510f..2aff953bc6 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): indexing_technique = node_data.indexing_technique or dataset.indexing_technique summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting + # Try to get document language if document_id is available + doc_language = None + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document and document.doc_language: + doc_language = document.doc_language + outputs = self._get_preview_output_with_summaries( node_data.chunk_structure, chunks, dataset=dataset, indexing_technique=indexing_technique, summary_index_setting=summary_index_setting, + doc_language=doc_language, ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): dataset: Dataset, indexing_technique: str | None = None, summary_index_setting: dict | None = None, + doc_language: str | None = None, ) -> Mapping[str, Any]: """ Generate preview output with summaries for chunks in preview mode. @@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): dataset: Dataset object (for tenant_id) indexing_technique: Indexing technique from node config or dataset summary_index_setting: Summary index setting from node config or dataset + doc_language: Optional document language to ensure summary is generated in the correct language """ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() preview_output = index_processor.format_preview(chunks) @@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): tenant_id=dataset.tenant_id, text=preview_item["content"], summary_index_setting=summary_index_setting, + document_language=doc_language, ) if summary: preview_item["summary"] = summary @@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): tenant_id=dataset.tenant_id, text=preview_item["content"], summary_index_setting=summary_index_setting, + document_language=doc_language, ) if summary: preview_item["summary"] = summary diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0b3fcbe4ae..16945fca6a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config +from core.db.session_factory import session_factory from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.file import helpers as file_helpers from core.helper.name_generator import generate_incremental_name @@ -1388,6 +1389,46 @@ class DocumentService: ).all() return documents + @staticmethod + def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int: + """ + Update need_summary field for multiple documents. + + This method handles the case where documents were created when summary_index_setting was disabled, + and need to be updated when summary_index_setting is later enabled. + + Args: + dataset_id: Dataset ID + document_ids: List of document IDs to update + need_summary: Value to set for need_summary field (default: True) + + Returns: + Number of documents updated + """ + if not document_ids: + return 0 + + document_id_list: list[str] = [str(document_id) for document_id in document_ids] + + with session_factory.create_session() as session: + updated_count = ( + session.query(Document) + .filter( + Document.id.in_(document_id_list), + Document.dataset_id == dataset_id, + Document.doc_form != "qa_model", # Skip qa_model documents + ) + .update({Document.need_summary: need_summary}, synchronize_session=False) + ) + session.commit() + logger.info( + "Updated need_summary to %s for %d documents in dataset %s", + need_summary, + updated_count, + dataset_id, + ) + return updated_count + @staticmethod def get_document_download_url(document: Document) -> str: """ diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 8ea365e907..d0dfbc1070 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -174,6 +174,10 @@ class RagPipelineTransformService: else: dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Copy summary_index_setting from dataset to knowledge_index node configuration + if dataset.summary_index_setting: + knowledge_configuration.summary_index_setting = dataset.summary_index_setting + knowledge_configuration_dict.update(knowledge_configuration.model_dump()) node["data"] = knowledge_configuration_dict return node diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index b8e1f8bc3f..7c03ceed5b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -49,11 +49,18 @@ class SummaryIndexService: # Use lazy import to avoid circular import from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + # Get document language to ensure summary is generated in the correct language + # This is especially important for image-only chunks where text is empty or minimal + document_language = None + if segment.document and segment.document.doc_language: + document_language = segment.document.doc_language + summary_content, usage = ParagraphIndexProcessor.generate_summary( tenant_id=dataset.tenant_id, text=segment.content, summary_index_setting=summary_index_setting, segment_id=segment.id, + document_language=document_language, ) if not summary_content: @@ -558,6 +565,9 @@ class SummaryIndexService: ) session.add(summary_record) + # Commit the batch created records + session.commit() + @staticmethod def update_summary_record_error( segment: DocumentSegment, @@ -762,7 +772,6 @@ class SummaryIndexService: dataset=dataset, status="not_started", ) - session.commit() # Commit initial records summary_records = [] From 603a896c496295042235dad1e94fbad74b21d927 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 2 Feb 2026 11:12:04 +0800 Subject: [PATCH 04/43] chore(CODEOWNERS): assign `.agents/skills` to @hyoban (#31816) Signed-off-by: -LAN- --- .github/CODEOWNERS | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 106c26bbed..36fa39b5d7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,9 @@ # CODEOWNERS file /.github/CODEOWNERS @laipz8200 @crazywoola +# Agents +/.agents/skills/ @hyoban + # Docs /docs/ @crazywoola From 9fb72c151cadf84d5c7353baf7076b43f2f5a952 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 2 Feb 2026 11:18:18 +0800 Subject: [PATCH 05/43] refactor: "chore: update version to 1.12.0" (#31817) --- api/pyproject.toml | 2 +- api/uv.lock | 2 +- docker/docker-compose-template.yaml | 8 ++++---- docker/docker-compose.yaml | 8 ++++---- web/package.json | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 97e6c83ed6..02d1aea21d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.12.0" +version = "1.11.4" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/uv.lock b/api/uv.lock index 04d9a7c021..ad84b35212 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.12.0" +version = "1.11.4" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index e27b51bcc0..eb8c2b53c5 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.11.4 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a0a755f570..02b8146aa9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -707,7 +707,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -749,7 +749,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -788,7 +788,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -818,7 +818,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.11.4 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/package.json b/web/package.json index 954366fc89..83a4f98dee 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "dify-web", "type": "module", - "version": "1.12.0", + "version": "1.11.4", "private": true, "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a", "imports": { From 840a975fef42b965700699928e9a02fe3e2383b4 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 14:54:16 +0900 Subject: [PATCH 06/43] =?UTF-8?q?refactor:=20add=20test=20for=20api/contro?= =?UTF-8?q?llers/console/workspace/tool=5Fpr=E2=80=A6=20(#29886)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/workspace/test_tool_providers.py | 364 ++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py new file mode 100644 index 0000000000..94c3019d5e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -0,0 +1,364 @@ +"""Endpoint tests for controllers.console.workspace.tool_providers.""" + +from __future__ import annotations + +import builtins +import importlib +from contextlib import contextmanager +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask.views import MethodView + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +_CONTROLLER_MODULE: ModuleType | None = None +_WRAPS_MODULE: ModuleType | None = None +_CONTROLLER_PATCHERS: list[patch] = [] + + +@contextmanager +def _mock_db(): + mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + with patch("extensions.ext_database.db.session", mock_session): + yield + + +@pytest.fixture +def app() -> Flask: + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture +def controller_module(monkeypatch: pytest.MonkeyPatch): + module_name = "controllers.console.workspace.tool_providers" + global _CONTROLLER_MODULE + if _CONTROLLER_MODULE is None: + + def _noop(func): + return func + + patch_targets = [ + ("libs.login.login_required", _noop), + ("controllers.console.wraps.setup_required", _noop), + ("controllers.console.wraps.account_initialization_required", _noop), + ("controllers.console.wraps.is_admin_or_owner_required", _noop), + ("controllers.console.wraps.enterprise_license_required", _noop), + ] + for target, value in patch_targets: + patcher = patch(target, value) + patcher.start() + _CONTROLLER_PATCHERS.append(patcher) + monkeypatch.setenv("DIFY_SETUP_READY", "true") + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) + + module = _CONTROLLER_MODULE + monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) + + # Ensure decorators that consult deployment edition do not reach the database. + global _WRAPS_MODULE + wraps_module = importlib.import_module("controllers.console.wraps") + _WRAPS_MODULE = wraps_module + monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD") + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None) + return module + + +def _mock_account(user_id: str = "user-123") -> SimpleNamespace: + return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None) + + +def _set_current_account( + monkeypatch: pytest.MonkeyPatch, + controller_module: ModuleType, + user: SimpleNamespace, + tenant_id: str, +) -> None: + def _getter(): + return user, tenant_id + + user.current_tenant_id = tenant_id + + monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter) + if _WRAPS_MODULE is not None: + monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter) + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "_get_user", lambda: user) + + +def test_tool_provider_list_calls_service_with_query( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value=[{"provider": "builtin"}]) + monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock) + + with app.test_request_context("/workspaces/current/tool-providers?type=builtin"): + response = controller_module.ToolProviderListApi().get() + + assert response == [{"provider": "builtin"}] + service_mock.assert_called_once_with(user.id, "tenant-456", "builtin") + + +def test_builtin_provider_add_passes_payload( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value={"status": "ok"}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock) + + payload = { + "credentials": {"api_key": "sk-test"}, + "name": "MyTool", + "type": controller_module.CredentialType.API_KEY, + } + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/openai/add", + method="POST", + json=payload, + ): + response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai") + + assert response == {"status": "ok"} + service_mock.assert_called_once_with( + user_id="user-123", + tenant_id="tenant-456", + provider="openai", + credentials={"api_key": "sk-test"}, + name="MyTool", + api_type=controller_module.CredentialType.API_KEY, + ) + + +def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-789") + _set_current_account(monkeypatch, controller_module, user, "tenant-789") + + service_mock = MagicMock(return_value=[{"name": "tool-a"}]) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock) + monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload) + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/my-provider/tools", + method="GET", + ): + response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider") + + assert response == [{"name": "tool-a"}] + service_mock.assert_called_once_with("tenant-789", "my-provider") + + +def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-9") + _set_current_account(monkeypatch, controller_module, user, "tenant-9") + service_mock = MagicMock(return_value={"info": True}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock) + + with app.test_request_context("/info", method="GET"): + resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo") + + assert resp == {"info": True} + service_mock.assert_called_once_with("tenant-9", "demo") + + +def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-cred") + _set_current_account(monkeypatch, controller_module, user, "tenant-cred") + service_mock = MagicMock(return_value=[{"cred": 1}]) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "get_builtin_tool_provider_credentials", + service_mock, + ) + + with app.test_request_context("/creds", method="GET"): + resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo") + + assert resp == [{"cred": 1}] + service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo") + + +def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-10") + service_mock = MagicMock(return_value={"schema": "ok"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock) + + with app.test_request_context("/remote?url=https://example.com/"): + resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get() + + assert resp == {"schema": "ok"} + service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/") + + +def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-11") + service_mock = MagicMock(return_value=[{"tool": "t"}]) + monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock) + + with app.test_request_context("/tools?provider=foo"): + resp = controller_module.ToolApiProviderListToolsApi().get() + + assert resp == [{"tool": "t"}] + service_mock.assert_called_once_with(user.id, "tenant-11", "foo") + + +def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-12") + service_mock = MagicMock(return_value={"provider": "foo"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock) + + with app.test_request_context("/get?provider=foo"): + resp = controller_module.ToolApiProviderGetApi().get() + + assert resp == {"provider": "foo"} + service_mock.assert_called_once_with(user.id, "tenant-12", "foo") + + +def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-13") + _set_current_account(monkeypatch, controller_module, user, "tenant-13") + service_mock = MagicMock(return_value={"schema": True}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_provider_credentials_schema", + service_mock, + ) + + with app.test_request_context("/schema", method="GET"): + resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get( + provider="demo", credential_type="api-key" + ) + + assert resp == {"schema": True} + service_mock.assert_called_once() + + +def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf") + tool_service = MagicMock(return_value={"wf": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_tool_id", + tool_service, + ) + + tool_id = "00000000-0000-0000-0000-000000000001" + with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"wf": 1} + tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id) + + +def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf2") + service_mock = MagicMock(return_value={"app": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_app_id", + service_mock, + ) + + app_id = "00000000-0000-0000-0000-000000000002" + with app.test_request_context(f"/workflow?workflow_app_id={app_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"app": 1} + service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id) + + +def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf3") + service_mock = MagicMock(return_value=[{"id": 1}]) + monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock) + + tool_id = "00000000-0000-0000-0000-000000000003" + with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderListToolApi().get() + + assert resp == [{"id": 1}] + service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id) + + +def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-bt") + + provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/builtin"): + resp = controller_module.ToolBuiltinListApi().get() + + assert resp == [{"name": "builtin"}] + + +def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-api") + _set_current_account(monkeypatch, controller_module, user, "tenant-api") + + provider = SimpleNamespace(to_dict=lambda: {"name": "api"}) + monkeypatch.setattr( + controller_module.ApiToolManageService, + "list_api_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/api"): + resp = controller_module.ToolApiListApi().get() + + assert resp == [{"name": "api"}] + + +def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf4") + + provider = SimpleNamespace(to_dict=lambda: {"name": "wf"}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "list_tenant_workflow_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/workflow"): + resp = controller_module.ToolWorkflowListApi().get() + + assert resp == [{"name": "wf"}] + + +def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-label") + _set_current_account(monkeypatch, controller_module, user, "tenant-labels") + monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"]) + + with app.test_request_context("/tool-labels"): + resp = controller_module.ToolLabelsApi().get() + + assert resp == ["a", "b"] From ac222a4dd4f030e06a0e0b47daa7c11d0514f0d1 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 18:03:07 +0900 Subject: [PATCH 07/43] refactor: port api/controllers/console/app/audio.py api/controllers/console/app/message.py api/controllers/console/auth/data_source_oauth.py api/controllers/console/auth/forgot_password.py api/controllers/console/workspace/endpoint.py (#30680) --- api/controllers/console/app/audio.py | 16 ++--- api/controllers/console/app/message.py | 31 +++++---- .../console/auth/data_source_oauth.py | 33 +++++++-- .../console/auth/forgot_password.py | 50 ++++++++------ api/controllers/console/workspace/endpoint.py | 69 ++++++++++++++----- .../clickzetta_volume_storage.py | 3 +- 6 files changed, 135 insertions(+), 67 deletions(-) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index d344ede466..941db325bf 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -33,7 +34,6 @@ from services.errors.audio import ( ) logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class TextToSpeechPayload(BaseModel): @@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel): language: str = Field(..., description="Language code") -console_ns.schema_model( - TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechVoiceQuery.__name__, - TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class AudioTranscriptResponse(BaseModel): + text: str = Field(description="Transcribed text from audio") + + +register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery) @console_ns.route("/apps//audio-to-text") @@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource): @console_ns.response( 200, "Audio transcription successful", - console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + console_ns.models[AudioTranscriptResponse.__name__], ) @console_ns.response(400, "Bad request - No audio uploaded or unsupported type") @console_ns.response(413, "Audio file too large") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 12ada8b798..0be3e0ec49 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, select from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( CompletionRequestError, @@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft from services.message_service import MessageService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ChatMessagesQuery(BaseModel): @@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel): raise ValueError("has_comment must be a boolean value") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class AnnotationCountResponse(BaseModel): + count: int = Field(description="Number of annotations") -reg(ChatMessagesQuery) -reg(MessageFeedbackPayload) -reg(FeedbackExportQuery) +class SuggestedQuestionsResponse(BaseModel): + data: list[str] = Field(description="Suggested question") + + +register_schema_models( + console_ns, + ChatMessagesQuery, + MessageFeedbackPayload, + FeedbackExportQuery, + AnnotationCountResponse, + SuggestedQuestionsResponse, +) # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -231,7 +240,7 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): - args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ChatMessagesQuery.model_validate(request.args.to_dict()) conversation = ( db.session.query(Conversation) @@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource): @console_ns.response( 200, "Annotation count retrieved successfully", - console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + console_ns.models[AnnotationCountResponse.__name__], ) @get_app_model @setup_required @@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource): @console_ns.response( 200, "Suggested questions retrieved successfully", - console_ns.model( - "SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))} - ), + console_ns.models[SuggestedQuestionsResponse.__name__], ) @console_ns.response(404, "Message or conversation not found") @setup_required @@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource): @login_required @account_initialization_required def get(self, app_model): - args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FeedbackExportQuery.model_validate(request.args.to_dict()) # Import the service function from services.feedback_service import FeedbackService diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0dd7d33ae9..3a3278ec9d 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,9 +2,11 @@ import logging import httpx from flask import current_app, redirect, request -from flask_restx import Resource, fields +from flask_restx import Resource +from pydantic import BaseModel, Field from configs import dify_config +from controllers.common.schema import register_schema_models from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required, logger = logging.getLogger(__name__) +class OAuthDataSourceResponse(BaseModel): + data: str = Field(description="Authorization URL or 'internal' for internal setup") + + +class OAuthDataSourceBindingResponse(BaseModel): + result: str = Field(description="Operation result") + + +class OAuthDataSourceSyncResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + OAuthDataSourceResponse, + OAuthDataSourceBindingResponse, + OAuthDataSourceSyncResponse, +) + + def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( @@ -34,10 +56,7 @@ class OAuthDataSource(Resource): @console_ns.response( 200, "Authorization URL or internal setup success", - console_ns.model( - "OAuthDataSourceResponse", - {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, - ), + console_ns.models[OAuthDataSourceResponse.__name__], ) @console_ns.response(400, "Invalid provider") @console_ns.response(403, "Admin privileges required") @@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource): @console_ns.response( 200, "Data source binding success", - console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceBindingResponse.__name__], ) @console_ns.response(400, "Invalid provider or code") def get(self, provider: str): @@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource): @console_ns.response( 200, "Data source sync success", - console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceSyncResponse.__name__], ) @console_ns.response(400, "Invalid provider or sync failed") @setup_required diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 394f205d93..1ed931b0d7 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,10 +2,11 @@ import base64 import secrets from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailCodeError, @@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel): return valid_password(value) -for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class ForgotPasswordEmailResponse(BaseModel): + result: str = Field(description="Operation result") + data: str | None = Field(default=None, description="Reset token") + code: str | None = Field(default=None, description="Error code if account not found") + + +class ForgotPasswordCheckResponse(BaseModel): + is_valid: bool = Field(description="Whether code is valid") + email: EmailStr = Field(description="Email address") + token: str = Field(description="New reset token") + + +class ForgotPasswordResetResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + ForgotPasswordSendPayload, + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordEmailResponse, + ForgotPasswordCheckResponse, + ForgotPasswordResetResponse, +) @console_ns.route("/forgot-password") @@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): @console_ns.response( 200, "Email sent successfully", - console_ns.model( - "ForgotPasswordEmailResponse", - { - "result": fields.String(description="Operation result"), - "data": fields.String(description="Reset token"), - "code": fields.String(description="Error code if account not found"), - }, - ), + console_ns.models[ForgotPasswordEmailResponse.__name__], ) @console_ns.response(400, "Invalid email or rate limit exceeded") @setup_required @@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource): @console_ns.response( 200, "Code verified successfully", - console_ns.model( - "ForgotPasswordCheckResponse", - { - "is_valid": fields.Boolean(description="Whether code is valid"), - "email": fields.String(description="Email address"), - "token": fields.String(description="New reset token"), - }, - ), + console_ns.models[ForgotPasswordCheckResponse.__name__], ) @console_ns.response(400, "Invalid code or token") @setup_required @@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource): @console_ns.response( 200, "Password reset successfully", - console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[ForgotPasswordResetResponse.__name__], ) @console_ns.response(400, "Invalid token or password mismatch") @setup_required diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index bfd9fc6c29..1897cbdca7 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,9 +1,10 @@ from typing import Any from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder @@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery): plugin_id: str +class EndpointCreateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class PluginEndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class EndpointDeleteResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointUpdateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointEnableResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointDisableResponse(BaseModel): + success: bool = Field(description="Operation success") + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -reg(EndpointCreatePayload) -reg(EndpointIdPayload) -reg(EndpointUpdatePayload) -reg(EndpointListQuery) -reg(EndpointListForPluginQuery) +register_schema_models( + console_ns, + EndpointCreatePayload, + EndpointIdPayload, + EndpointUpdatePayload, + EndpointListQuery, + EndpointListForPluginQuery, + EndpointCreateResponse, + EndpointListResponse, + PluginEndpointListResponse, + EndpointDeleteResponse, + EndpointUpdateResponse, + EndpointEnableResponse, + EndpointDisableResponse, +) @console_ns.route("/workspaces/current/endpoints/create") @@ -57,7 +96,7 @@ class EndpointCreateApi(Resource): @console_ns.response( 200, "Endpoint created successfully", - console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointCreateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -91,9 +130,7 @@ class EndpointListApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[EndpointListResponse.__name__], ) @setup_required @login_required @@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[PluginEndpointListResponse.__name__], ) @setup_required @login_required @@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource): @console_ns.response( 200, "Endpoint deleted successfully", - console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDeleteResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource): @console_ns.response( 200, "Endpoint updated successfully", - console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointUpdateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -221,7 +256,7 @@ class EndpointEnableApi(Resource): @console_ns.response( 200, "Endpoint enabled successfully", - console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointEnableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -248,7 +283,7 @@ class EndpointDisableApi(Resource): @console_ns.response( 200, "Endpoint disabled successfully", - console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDisableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index c1608f58a5..18eed4e481 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage): """ content = self.load_once(filename) - with Path(target_filepath).open("wb") as f: - f.write(content) + Path(target_filepath).write_bytes(content) logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) From 920db69ef2d52034df50a4f1821a7c63d003a544 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 18:12:03 +0900 Subject: [PATCH 08/43] refactor: if to match (#31799) --- api/commands.py | 203 +++++++++--------- api/controllers/console/app/conversation.py | 23 +- .../console/datasets/datasets_document.py | 107 +++++---- api/controllers/service_api/wraps.py | 16 +- 4 files changed, 179 insertions(+), 170 deletions(-) diff --git a/api/commands.py b/api/commands.py index 4b811fb1e6..c4f2c9edbb 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool): all_ids_in_tables = [] for ids_table in ids_tables: query = "" - if ids_table["type"] == "uuid": - click.echo( - click.style( - f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white" + match ids_table["type"]: + case "uuid": + click.echo( + click.style( + f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", + fg="white", + ) ) - ) - query = ( - f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) - elif ids_table["type"] == "text": - click.echo( - click.style( - f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}", - fg="white", + c = ids_table["column"] + query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) + case "text": + t = ids_table["table"] + click.echo( + click.style( + f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", + fg="white", + ) ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - elif ids_table["type"] == "json": - click.echo( - click.style( - ( - f"- Listing file-id-like JSON string in column {ids_table['column']} " - f"in table {ids_table['table']}" - ), - fg="white", + query = ( + f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case "json": + click.echo( + click.style( + ( + f"- Listing file-id-like JSON string in column {ids_table['column']} " + f"in table {ids_table['table']}" + ), + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case _: + pass click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) except Exception as e: @@ -1737,59 +1741,18 @@ def file_usage( if src_filter != src: continue - if ids_table["type"] == "uuid": - # Direct UUID match - query = ( - f"SELECT {ids_table['pk_column']}, {ids_table['column']} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - ref_file_id = str(row[1]) - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - "src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - - elif ids_table["type"] in ("text", "json"): - # Extract UUIDs from text/json content - column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] - query = ( - f"SELECT {ids_table['pk_column']}, {column_cast} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - content = str(row[1]) - - # Find all UUIDs in the content - import re - - uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) - matches = uuid_pattern.findall(content) - - for ref_file_id in matches: + match ids_table["type"]: + case "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) if ref_file_id not in file_key_map: continue storage_key = file_key_map[ref_file_id] @@ -1812,6 +1775,50 @@ def file_usage( ) total_count += 1 + case "text" | "json": + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + case _: + pass + # Output results if output_json: result = { diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 55fdcb51e4..82cc957d04 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -508,16 +508,19 @@ class ChatConversationApi(Resource): case "created_at" | "-created_at" | _: query = query.where(Conversation.created_at <= end_datetime_utc) - if args.annotation_status == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ) - elif args.annotation_status == "not_annotated": - query = ( - query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) - .group_by(Conversation.id) - .having(func.count(MessageAnnotation.id) == 0) - ) + match args.annotation_status: + case "annotated": + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + case "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) + case "all": + pass if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6a0c9e5f77..e8b8f2ec6d 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict + match document.data_source_type: + case "upload_file": + if not data_source_info: + continue + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) + .first() + ) - if document.data_source_type == "upload_file": - if not data_source_info: - continue - file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() - ) + if file_detail is None: + raise NotFound("File not found.") - if file_detail is None: - raise NotFound("File not found.") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form + ) + extract_settings.append(extract_setting) + case "notion_import": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_tenant_id, + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) + case "website_crawl": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], + "tenant_id": current_tenant_id, + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) - extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form - ) - extract_settings.append(extract_setting) - - elif document.data_source_type == "notion_import": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "tenant_id": current_tenant_id, - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - elif document.data_source_type == "website_crawl": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "url": data_source_info["url"], - "tenant_id": current_tenant_id, - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - - else: - raise ValueError("Data source type not support") + case _: + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: response = indexing_runner.indexing_estimate( diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 24acced0d1..e597a72fc0 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe # If caller needs end-user context, attach EndUser to current_user if fetch_user_arg: - if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get("user") - else: - user_id = None + user_id = None + match fetch_user_arg.fetch_from: + case WhereisUserArg.QUERY: + user_id = request.args.get("user") + case WhereisUserArg.JSON: + user_id = request.get_json().get("user") + case WhereisUserArg.FORM: + user_id = request.form.get("user") if not user_id and fetch_user_arg.required: raise ValueError("Arg user must be provided.") From ce2c41bbf5662f2334e5b178f1d807c62b44c20e Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 19:07:30 +0900 Subject: [PATCH 09/43] refactor: port api/controllers/console/datasets/datasets_document.py api/controllers/service_api/app/annotation.py api/core/app/app_config/easy_ui_based_app/agent/manager.py api/core/app/apps/pipeline/pipeline_generator.py api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py to match case (#31832) --- .../console/datasets/datasets_document.py | 29 +-- api/controllers/service_api/app/annotation.py | 9 +- .../easy_ui_based_app/agent/manager.py | 17 +- .../app/apps/pipeline/pipeline_generator.py | 30 +-- api/core/indexing_runner.py | 113 ++++++----- .../knowledge_retrieval_node.py | 180 +++++++++--------- api/core/workflow/nodes/tool/tool_node.py | 21 +- 7 files changed, 202 insertions(+), 197 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index e8b8f2ec6d..bf097d374a 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -953,23 +953,24 @@ class DocumentProcessingApi(DocumentResource): if not current_user.is_dataset_editor: raise Forbidden() - if action == "pause": - if document.indexing_status != "indexing": - raise InvalidActionError("Document not in indexing state.") + match action: + case "pause": + if document.indexing_status != "indexing": + raise InvalidActionError("Document not in indexing state.") - document.paused_by = current_user.id - document.paused_at = naive_utc_now() - document.is_paused = True - db.session.commit() + document.paused_by = current_user.id + document.paused_at = naive_utc_now() + document.is_paused = True + db.session.commit() - elif action == "resume": - if document.indexing_status not in {"paused", "error"}: - raise InvalidActionError("Document not in paused or error state.") + case "resume": + if document.indexing_status not in {"paused", "error"}: + raise InvalidActionError("Document not in paused or error state.") - document.paused_by = None - document.paused_at = None - document.is_paused = False - db.session.commit() + document.paused_by = None + document.paused_at = None + document.is_paused = False + db.session.commit() return {"result": "success"}, 200 diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 85ac9336d6..5be146a13e 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -45,10 +45,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_model.id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args, app_model.id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index c1f336fdde..9b981dfc09 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -14,16 +14,17 @@ class AgentConfigManager: agent_dict = config.get("agent_mode", {}) agent_strategy = agent_dict.get("strategy", "cot") - if agent_strategy == "function_call": - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy in {"cot", "react"}: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to detect default strategy - if config["model"]["provider"] == "openai": + match agent_strategy: + case "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: + case "cot" | "react": strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + case _: + # old configs, try to detect default strategy + if config["model"]["provider"] == "openai": + strategy = AgentEntity.Strategy.FUNCTION_CALLING + else: + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] for tool in agent_dict.get("tools", []): diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index ea4441b5d8..eca96cb074 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("Pipeline dataset is required") inputs: Mapping[str, Any] = args["inputs"] start_node_id: str = args["start_node_id"] - datasource_type: str = args["datasource_type"] + datasource_type = DatasourceProviderType(args["datasource_type"]) datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list( datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user ) @@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator): tenant_id: str, dataset_id: str, built_in_field_enabled: bool, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info: Mapping[str, Any], created_from: str, position: int, @@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator): batch: str, document_form: str, ): - if datasource_type == "local_file": - name = datasource_info.get("name", "untitled") - elif datasource_type == "online_document": - name = datasource_info.get("page", {}).get("page_name", "untitled") - elif datasource_type == "website_crawl": - name = datasource_info.get("title", "untitled") - elif datasource_type == "online_drive": - name = datasource_info.get("name", "untitled") - else: - raise ValueError(f"Unsupported datasource type: {datasource_type}") - + match datasource_type: + case DatasourceProviderType.LOCAL_FILE: + name = datasource_info.get("name", "untitled") + case DatasourceProviderType.ONLINE_DOCUMENT: + name = datasource_info.get("page", {}).get("page_name", "untitled") + case DatasourceProviderType.WEBSITE_CRAWL: + name = datasource_info.get("title", "untitled") + case DatasourceProviderType.ONLINE_DRIVE: + name = datasource_info.get("name", "untitled") + case _: + raise ValueError(f"Unsupported datasource type: {datasource_type}") document = Document( tenant_id=tenant_id, dataset_id=dataset_id, @@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator): def _format_datasource_info_list( self, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info_list: list[Mapping[str, Any]], pipeline: Pipeline, workflow: Workflow, @@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator): """ Format datasource info list. """ - if datasource_type == "online_drive": + if datasource_type == DatasourceProviderType.ONLINE_DRIVE: all_files: list[Mapping[str, Any]] = [] datasource_node_data = None datasource_nodes = workflow.graph_dict.get("nodes", []) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 61f168a26f..4e3ad7bb75 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -378,70 +378,69 @@ class IndexingRunner: def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict ) -> list[Document]: - # load file - if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: - return [] - data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == "upload_file": - if not data_source_info or "upload_file_id" not in data_source_info: - raise ValueError("no upload file found") - stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) - file_detail = db.session.scalars(stmt).one_or_none() + match dataset_document.data_source_type: + case "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: + raise ValueError("no upload file found") + stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) + file_detail = db.session.scalars(stmt).one_or_none() - if file_detail: + if file_detail: + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, + upload_file=file_detail, + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): + raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, - upload_file=file_detail, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id, + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "notion_import": - if ( - not data_source_info - or "notion_workspace_id" not in data_source_info - or "notion_page_id" not in data_source_info - ): - raise ValueError("no notion import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "document": dataset_document, - "tenant_id": dataset_document.tenant_id, - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "website_crawl": - if ( - not data_source_info - or "provider" not in data_source_info - or "url" not in data_source_info - or "job_id" not in data_source_info - ): - raise ValueError("no website import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "tenant_id": dataset_document.tenant_id, - "url": data_source_info["url"], - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): + raise ValueError("no website import info found") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case _: + return [] # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 3c4850ebac..0827494a48 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": - if node_data.multiple_retrieval_config.reranking_model: - reranking_model = { - "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, - "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, - } - else: + match node_data.multiple_retrieval_config.reranking_mode: + case "reranking_model": + if node_data.multiple_retrieval_config.reranking_model: + reranking_model = { + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, + } + else: + reranking_model = None + weights = None + case "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None - weights = None - elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": - if node_data.multiple_retrieval_config.weights is None: - raise ValueError("weights is required") - reranking_model = None - vector_setting = node_data.multiple_retrieval_config.weights.vector_setting - weights = { - "vector_setting": { - "vector_weight": vector_setting.vector_weight, - "embedding_provider_name": vector_setting.embedding_provider_name, - "embedding_model_name": vector_setting.embedding_model_name, - }, - "keyword_setting": { - "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - }, - } - else: - reranking_model = None - weights = None + vector_setting = node_data.multiple_retrieval_config.weights.vector_setting + weights = { + "vector_setting": { + "vector_weight": vector_setting.vector_weight, + "embedding_provider_name": vector_setting.embedding_provider_name, + "embedding_model_name": vector_setting.embedding_model_name, + }, + "keyword_setting": { + "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight + }, + } + case _: + reranking_model = None + weights = None all_documents = dataset_retrieval.multiple_retrieve( app_id=self.app_id, tenant_id=self.tenant_id, @@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) filters: list[Any] = [] metadata_condition = None - if node_data.metadata_filtering_mode == "disabled": - return None, None, usage - elif node_data.metadata_filtering_mode == "automatic": - automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( - dataset_ids, query, node_data - ) - usage = self._merge_usage(usage, automatic_usage) - if automatic_metadata_filters: - conditions = [] - for sequence, filter in enumerate(automatic_metadata_filters): - DatasetRetrieval.process_metadata_filter_func( - sequence, - filter.get("condition", ""), - filter.get("metadata_name", ""), - filter.get("value"), - filters, - ) - conditions.append( - Condition( - name=filter.get("metadata_name"), # type: ignore - comparison_operator=filter.get("condition"), # type: ignore - value=filter.get("value"), - ) - ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator - if node_data.metadata_filtering_conditions - else "or", - conditions=conditions, + match node_data.metadata_filtering_mode: + case "disabled": + return None, None, usage + case "automatic": + automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( + dataset_ids, query, node_data ) - elif node_data.metadata_filtering_mode == "manual": - if node_data.metadata_filtering_conditions: - conditions = [] - for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore - metadata_name = condition.name - expected_value = condition.value - if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): - if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: - expected_value = expected_value.value - elif expected_value.value_type == "string": - expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() - else: - raise ValueError("Invalid expected metadata value type") - conditions.append( - Condition( - name=metadata_name, - comparison_operator=condition.comparison_operator, - value=expected_value, + usage = self._merge_usage(usage, automatic_usage) + if automatic_metadata_filters: + conditions = [] + for sequence, filter in enumerate(automatic_metadata_filters): + DatasetRetrieval.process_metadata_filter_func( + sequence, + filter.get("condition", ""), + filter.get("metadata_name", ""), + filter.get("value"), + filters, ) + conditions.append( + Condition( + name=filter.get("metadata_name"), # type: ignore + comparison_operator=filter.get("condition"), # type: ignore + value=filter.get("value"), + ) + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator + if node_data.metadata_filtering_conditions + else "or", + conditions=conditions, ) - filters = DatasetRetrieval.process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, + case "manual": + if node_data.metadata_filtering_conditions: + conditions = [] + for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore + metadata_name = condition.name + expected_value = condition.value + if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): + if isinstance(expected_value, str): + expected_value = self.graph_runtime_state.variable_pool.convert_template( + expected_value + ).value[0] + if expected_value.value_type in {"number", "integer", "float"}: + expected_value = expected_value.value + elif expected_value.value_type == "string": + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() + else: + raise ValueError("Invalid expected metadata value type") + conditions.append( + Condition( + name=metadata_name, + comparison_operator=condition.comparison_operator, + value=expected_value, + ) + ) + filters = DatasetRetrieval.process_metadata_filter_func( + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator, + conditions=conditions, ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator, - conditions=conditions, - ) - else: - raise ValueError("Invalid metadata filtering mode") + case _: + raise ValueError("Invalid metadata filtering mode") if filters: if ( node_data.metadata_filtering_conditions diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 68ac60e4f6..60d76db9b6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]): result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + selector_key = ".".join(input.value) + result[f"#{selector_key}#"] = input.value + case "constant": + pass result = {node_id + "." + key: value for key, value in result.items()} From 491fa9923b8d1fd3820a5df40eda4ebf22affdf5 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 21:03:16 +0900 Subject: [PATCH 10/43] refactor: port api/controllers/console/datasets/data_source.py /datasets/metadata.py /service_api/dataset/metadata.py /nodes/agent/agent_node.py api/core/workflow/nodes/datasource/datasource_node.py api/services/dataset_service.py to match case (#31836) --- .../console/datasets/data_source.py | 40 ++-- api/controllers/console/datasets/metadata.py | 9 +- .../service_api/dataset/metadata.py | 9 +- api/core/workflow/nodes/agent/agent_node.py | 64 +++--- .../nodes/datasource/datasource_node.py | 193 +++++++++--------- api/services/dataset_service.py | 114 ++++++----- 6 files changed, 223 insertions(+), 206 deletions(-) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01e9bf77c0..daef4e005a 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator -from typing import Any, cast +from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with @@ -157,9 +157,8 @@ class DataSourceApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, binding_id, action): + def patch(self, binding_id, action: Literal["enable", "disable"]): binding_id = str(binding_id) - action = str(action) with Session(db.engine) as session: data_source_binding = session.execute( select(DataSourceOauthBinding).filter_by(id=binding_id) @@ -167,23 +166,24 @@ class DataSourceApi(Resource): if data_source_binding is None: raise NotFound("Data source binding not found.") # enable binding - if action == "enable": - if data_source_binding.disabled: - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is not disabled.") - # disable binding - if action == "disable": - if not data_source_binding.disabled: - data_source_binding.disabled = True - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is disabled.") + match action: + case "enable": + if data_source_binding.disabled: + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is not disabled.") + # disable binding + case "disable": + if not data_source_binding.disabled: + data_source_binding.disabled = True + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is disabled.") return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 05fc4cd714..2e69ddc5ab 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index b8d9508004..692342a38a 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5a365f769d..e195aebe6d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]): result[parameter_name] = None continue agent_input = node_data.agent_parameters[parameter_name] - if agent_input.type == "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - elif agent_input.type in {"mixed", "constant"}: - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + # variable_pool.convert_template expects a string template, + # but if passing a dict, convert to JSON string first before rendering + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - parameter_value = parameter_value - else: - raise AgentInputTypeError(agent_input.type) + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + # variable_pool.convert_template returns a string, + # so we need to convert it back to a dictionary + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == "array[tools]": value = cast(list[dict[str, Any]], value) @@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]): result: dict[str, Any] = {} for parameter_name in typed_node_data.agent_parameters: input = typed_node_data.agent_parameters[parameter_name] - if input.type in ["mixed", "constant"]: - selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value + match input.type: + case "mixed" | "constant": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value result = {node_id + "." + key: value for key, value in result.items()} diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index fd71d610b4..a732a70417 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -270,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]): if typed_node_data.datasource_parameters: for parameter_name in typed_node_data.datasource_parameters: input = typed_node_data.datasource_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value + case "constant": + pass + case None: + pass result = {node_id + "." + key: value for key, value in result.items()} @@ -308,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]): variables: dict[str, Any] = {} for message in message_stream: - if message.type in { - DatasourceMessage.MessageType.IMAGE_LINK, - DatasourceMessage.MessageType.BINARY_LINK, - DatasourceMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, DatasourceMessage.TextMessage) + match message.type: + case ( + DatasourceMessage.MessageType.IMAGE_LINK + | DatasourceMessage.MessageType.BINARY_LINK + | DatasourceMessage.MessageType.IMAGE + ): + assert isinstance(message.message, DatasourceMessage.TextMessage) - url = message.message.text - transfer_method = FileTransferMethod.TOOL_FILE + url = message.message.text + transfer_method = FileTransferMethod.TOOL_FILE - datasource_file_id = str(url).split("/")[-1].split(".")[0] + datasource_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") - mapping = { - "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - elif message.type == DatasourceMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceMessage.TextMessage) - assert message.meta - - datasource_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"datasource file {datasource_file_id} not exists") - - mapping = { - "tool_file_id": datasource_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( + mapping = { + "tool_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( mapping=mapping, tenant_id=self.tenant_id, ) - ) - elif message.type == DatasourceMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceMessage.JsonMessage) - json.append(message.message.json_object) - elif message.type == DatasourceMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value + files.append(file) + case DatasourceMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceMessage.TextMessage) + assert message.meta + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "tool_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + case DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) + text += message.message.text yield StreamChunkEvent( - selector=[self._node_id, variable_name], - chunk=variable_value, + selector=[self._node_id, "text"], + chunk=message.message.text, is_final=False, ) - else: - variables[variable_name] = variable_value - elif message.type == DatasourceMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) + case DatasourceMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceMessage.JsonMessage) + json.append(message.message.json_object) + case DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=stream_text, + is_final=False, + ) + case DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[self._node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + case DatasourceMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + case ( + DatasourceMessage.MessageType.BLOB_CHUNK + | DatasourceMessage.MessageType.LOG + | DatasourceMessage.MessageType.RETRIEVER_RESOURCES + ): + pass + # mark the end of the stream yield StreamChunkEvent( selector=[self._node_id, "text"], diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 16945fca6a..1ea6c4e1c3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2978,14 +2978,15 @@ class DocumentService: """ now = naive_utc_now() - if action == "enable": - return DocumentService._prepare_enable_update(document, now) - elif action == "disable": - return DocumentService._prepare_disable_update(document, user, now) - elif action == "archive": - return DocumentService._prepare_archive_update(document, user, now) - elif action == "un_archive": - return DocumentService._prepare_unarchive_update(document, now) + match action: + case "enable": + return DocumentService._prepare_enable_update(document, now) + case "disable": + return DocumentService._prepare_disable_update(document, user, now) + case "archive": + return DocumentService._prepare_archive_update(document, user, now) + case "un_archive": + return DocumentService._prepare_unarchive_update(document, now) return None @@ -3622,56 +3623,57 @@ class SegmentService: # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - if action == "enable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == False, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + match action: + case "enable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - elif action == "disable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == True, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.disabled_by = current_user.id - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + case "disable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) @classmethod def create_child_chunk( From 47f8de3f8ec8c89450ed8c0e92f2347fc1a83765 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 3 Feb 2026 10:59:00 +0900 Subject: [PATCH 11/43] refactor: port api/controllers/console/app/annotation.py api/controllers/console/explore/trial.py api/controllers/console/workspace/account.py api/controllers/console/workspace/members.py api/controllers/service_api/app/annotation.py to basemodel (#31833) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/app/annotation.py | 81 +++++++------ api/controllers/console/explore/trial.py | 6 +- api/controllers/console/workspace/account.py | 41 ++++--- api/controllers/console/workspace/members.py | 27 +++-- api/controllers/service_api/app/annotation.py | 67 +++++------ .../service_api/dataset/dataset.py | 19 +-- api/fields/annotation_fields.py | 89 +++++++++----- api/fields/end_user_fields.py | 22 +++- api/fields/member_fields.py | 109 +++++++++++++----- api/fields/tag_fields.py | 26 +++-- api/fields/workflow_app_log_fields.py | 20 +--- api/services/annotation_service.py | 4 +- 12 files changed, 307 insertions(+), 204 deletions(-) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index a07145ce9f..9931bb5dd7 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,10 +1,11 @@ from typing import Any, Literal from flask import abort, make_response, request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, field_validator from controllers.common.errors import NoFileUploadedError, TooManyFilesError +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -16,9 +17,11 @@ from controllers.console.wraps import ( ) from extensions.ext_redis import redis_client from fields.annotation_fields import ( - annotation_fields, - annotation_hit_history_fields, - build_annotation_model, + Annotation, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, + AnnotationList, ) from libs.helper import uuid_value from libs.login import login_required @@ -89,6 +92,14 @@ reg(CreateAnnotationPayload) reg(UpdateAnnotationPayload) reg(AnnotationReplyStatusQuery) reg(AnnotationFilePayload) +register_schema_models( + console_ns, + Annotation, + AnnotationList, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, +) @console_ns.route("/apps//annotation-reply/") @@ -202,33 +213,33 @@ class AnnotationApi(Resource): app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) - response = { - "data": marshal(annotation_list, annotation_fields), - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response, 200 + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json"), 200 @console_ns.doc("create_annotation") @console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__]) - @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) + @console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - @marshal_with(annotation_fields) @edit_permission_required def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) data = args.model_dump(exclude_none=True) annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -265,7 +276,7 @@ class AnnotationExportApi(Resource): @console_ns.response( 200, "Annotations exported successfully", - console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}), + console_ns.models[AnnotationExportList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -275,7 +286,8 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response_data = {"data": marshal(annotation_list, annotation_fields)} + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json") # Create response with secure headers for CSV export response = make_response(response_data, 200) @@ -290,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.doc("update_delete_annotation") @console_ns.doc(description="Update or delete an annotation") @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) - @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) + @console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__]) @console_ns.response(204, "Annotation deleted successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__]) @@ -299,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - @marshal_with(annotation_fields) def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) @@ -307,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource): annotation = AppAnnotationService.update_app_annotation_directly( args.model_dump(exclude_none=True), app_id, annotation_id ) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -415,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource): @console_ns.response( 200, "Hit histories retrieved successfully", - console_ns.model( - "AnnotationHitHistoryList", - { - "data": fields.List( - fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields)) - ) - }, - ), + console_ns.models[AnnotationHitHistoryList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -437,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource): annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( app_id, annotation_id, page, limit ) - response = { - "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), - "has_more": len(annotation_hit_history_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response + history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python( + annotation_hit_history_list, from_attributes=True + ) + response = AnnotationHitHistoryList( + data=history_models, + has_more=len(annotation_hit_history_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1eb0cdb019..cd523b481c 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -9,7 +9,7 @@ import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse from controllers.common.schema import get_or_create_model -from controllers.console import api, console_ns +from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -51,7 +51,7 @@ from fields.app_fields import ( tag_fields, ) from fields.dataset_fields import dataset_fields -from fields.member_fields import build_simple_account_model +from fields.member_fields import simple_account_fields from fields.workflow_fields import ( conversation_variable_fields, pipeline_variable_fields, @@ -103,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model)) app_detail_fields_with_site_copy["site"] = fields.Nested(site_model) app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy) -simple_account_model = build_simple_account_model(console_ns) +simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields) conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields) pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 38c66525b3..708df62642 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, @@ -37,7 +38,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_fields +from fields.member_fields import Account as AccountResponse from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required @@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload) reg(ChangeEmailValidityPayload) reg(ChangeEmailResetPayload) reg(CheckEmailUniquePayload) +register_schema_models(console_ns, AccountResponse) + + +def _serialize_account(account) -> dict: + return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") + integrate_fields = { "provider": fields.String, @@ -236,11 +243,11 @@ class AccountProfileApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) @enterprise_license_required def get(self): current_user, _ = current_account_with_tenant() - return current_user + return _serialize_account(current_user) @console_ns.route("/account/name") @@ -249,14 +256,14 @@ class AccountNameApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} args = AccountNamePayload.model_validate(payload) updated_account = AccountService.update_account(current_user, name=args.name) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/avatar") @@ -265,7 +272,7 @@ class AccountAvatarApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -273,7 +280,7 @@ class AccountAvatarApi(Resource): updated_account = AccountService.update_account(current_user, avatar=args.avatar) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-language") @@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource): updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-theme") @@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource): updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/timezone") @@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource): updated_account = AccountService.update_account(current_user, timezone=args.timezone) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/password") @@ -333,7 +340,7 @@ class AccountPasswordApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -344,7 +351,7 @@ class AccountPasswordApi(Resource): except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() - return {"result": "success"} + return _serialize_account(current_user) @console_ns.route("/account/integrates") @@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) @@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource): email=normalized_new_email, ) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/change-email/check-email-unique") diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 271cdce3c3..dd302b90d6 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,12 +1,12 @@ from urllib import parse from flask import abort, request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter import services from configs import dify_config -from controllers.common.schema import get_or_create_model, register_enum_models +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, @@ -25,7 +25,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_with_role_fields, account_with_role_list_fields +from fields.member_fields import AccountWithRole, AccountWithRoleList from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole @@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload) reg(OwnerTransferCheckPayload) reg(OwnerTransferPayload) register_enum_models(console_ns, TenantAccountRole) - -account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields) - -account_with_role_list_fields_copy = account_with_role_list_fields.copy() -account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model)) -account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy) +register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) @console_ns.route("/workspaces/current/members") @@ -84,13 +79,15 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/invite-email") @@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 5be146a13e..ef254ca357 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,16 +1,16 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields +from flask_restx import Resource from flask_restx.api import HTTPStatus -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client -from fields.annotation_fields import annotation_fields, build_annotation_model +from fields.annotation_fields import Annotation, AnnotationList from models.model import App from services.annotation_service import AppAnnotationService @@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel): embedding_model_name: str = Field(description="Embedding model name") -register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload) +register_schema_models( + service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList +) @service_api_ns.route("/apps/annotation-reply/") @@ -83,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 -# Define annotation list response model -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), - "has_more": fields.Boolean, - "limit": fields.Integer, - "total": fields.Integer, - "page": fields.Integer, -} - - -def build_annotation_list_model(api_or_ns: Namespace): - """Build the annotation list model for the API or Namespace.""" - copied_annotation_list_fields = annotation_list_fields.copy() - copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) - return api_or_ns.model("AnnotationList", copied_annotation_list_fields) - - @service_api_ns.route("/apps/annotations") class AnnotationListApi(Resource): @service_api_ns.doc("list_annotations") @@ -110,8 +95,12 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + 200, + "Annotations retrieved successfully", + service_api_ns.models[AnnotationList.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) def get(self, app_model: App): """List annotations for the application.""" page = request.args.get("page", default=1, type=int) @@ -119,13 +108,15 @@ class AnnotationListApi(Resource): keyword = request.args.get("keyword", default="", type=str) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) - return { - "data": annotation_list, - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("create_annotation") @@ -136,13 +127,18 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + HTTPStatus.CREATED, + "Annotation created successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): """Create a new annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) - return annotation, 201 + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json"), HTTPStatus.CREATED @service_api_ns.route("/apps/annotations/") @@ -159,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource): 404: "Annotation not found", } ) + @service_api_ns.response( + 200, + "Annotation updated successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token @edit_permission_required - @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) - return annotation + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json") @service_api_ns.doc("delete_annotation") @service_api_ns.doc(description="Delete an annotation") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c11f64585a..db5cabe8aa 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -17,7 +17,7 @@ from controllers.service_api.wraps import ( from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from fields.tag_fields import build_dataset_tag_fields +from fields.tag_fields import DataSetTag from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum @@ -114,6 +114,7 @@ register_schema_models( TagBindingPayload, TagUnbindingPayload, DatasetListQuery, + DataSetTag, ) @@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _): """Get all knowledge type tags.""" assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None tags = TagService.get_tags("knowledge", cid) - - return tags, 200 + tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True) + return [tag.model_dump(mode="json") for tag in tag_models], 200 @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) @service_api_ns.doc("create_dataset_tag") @@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def post(self, _): """Add a knowledge type tag.""" assert isinstance(current_user, Account) @@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource): payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__]) @@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def patch(self, _): assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource): binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__]) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index e69306dcb2..a646950722 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,36 +1,69 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -annotation_fields = { - "id": fields.String, - "question": fields.String, - "answer": fields.Raw(attribute="content"), - "hit_count": fields.Integer, - "created_at": TimestampField, - # 'account': fields.Nested(simple_account_fields, allow_null=True) -} +from pydantic import BaseModel, ConfigDict, Field, field_validator -def build_annotation_model(api_or_ns: Namespace): - """Build the annotation model for the API or Namespace.""" - return api_or_ns.model("Annotation", annotation_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), -} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -annotation_hit_history_fields = { - "id": fields.String, - "source": fields.String, - "score": fields.Float, - "question": fields.String, - "created_at": TimestampField, - "match": fields.String(attribute="annotation_question"), - "response": fields.String(attribute="annotation_content"), -} -annotation_hit_history_list_fields = { - "data": fields.List(fields.Nested(annotation_hit_history_fields)), -} +class Annotation(ResponseModel): + id: str + question: str | None = None + answer: str | None = Field(default=None, validation_alias="content") + hit_count: int | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationList(ResponseModel): + data: list[Annotation] + has_more: bool + limit: int + total: int + page: int + + +class AnnotationExportList(ResponseModel): + data: list[Annotation] + + +class AnnotationHitHistory(ResponseModel): + id: str + source: str | None = None + score: float | None = None + question: str | None = None + created_at: int | None = None + match: str | None = Field(default=None, validation_alias="annotation_question") + response: str | None = Field(default=None, validation_alias="annotation_content") + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationHitHistoryList(ResponseModel): + data: list[AnnotationHitHistory] + has_more: bool + limit: int + total: int + page: int diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 5389b0213a..effe7bfb20 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,7 @@ -from flask_restx import Namespace, fields +from __future__ import annotations + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict simple_end_user_fields = { "id": fields.String, @@ -8,5 +11,18 @@ simple_end_user_fields = { } -def build_simple_end_user_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleEndUser", simple_end_user_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleEndUser(ResponseModel): + id: str + type: str + is_anonymous: bool + session_id: str | None = None diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 25160927e6..11d9a1a2fc 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,6 +1,11 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import AvatarUrlField, TimestampField +from datetime import datetime + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict, computed_field, field_validator + +from core.file import helpers as file_helpers simple_account_fields = { "id": fields.String, @@ -9,36 +14,78 @@ simple_account_fields = { } -def build_simple_account_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleAccount", simple_account_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -account_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "is_password_set": fields.Boolean, - "interface_language": fields.String, - "interface_theme": fields.String, - "timezone": fields.String, - "last_login_at": TimestampField, - "last_login_ip": fields.String, - "created_at": TimestampField, -} +def _build_avatar_url(avatar: str | None) -> str | None: + if avatar is None: + return None + if avatar.startswith(("http://", "https://")): + return avatar + return file_helpers.get_signed_file_url(avatar) -account_with_role_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "last_login_at": TimestampField, - "last_active_at": TimestampField, - "created_at": TimestampField, - "role": fields.String, - "status": fields.String, -} -account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleAccount(ResponseModel): + id: str + name: str + email: str + + +class _AccountAvatar(ResponseModel): + avatar: str | None = None + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def avatar_url(self) -> str | None: + return _build_avatar_url(self.avatar) + + +class Account(_AccountAvatar): + id: str + name: str + email: str + is_password_set: bool + interface_language: str | None = None + interface_theme: str | None = None + timezone: str | None = None + last_login_at: int | None = None + last_login_ip: str | None = None + created_at: int | None = None + + @field_validator("last_login_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRole(_AccountAvatar): + id: str + name: str + email: str + last_login_at: int | None = None + last_active_at: int | None = None + created_at: int | None = None + role: str + status: str + + @field_validator("last_login_at", "last_active_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRoleList(ResponseModel): + accounts: list[AccountWithRole] diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index e359a4408c..7cb64e5ca8 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,12 +1,20 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} +from pydantic import BaseModel, ConfigDict -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class DataSetTag(ResponseModel): + id: str + name: str + type: str + binding_count: str | None = None diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index ae70356322..d0e762f62b 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,7 +1,7 @@ from flask_restx import Namespace, fields -from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields -from fields.member_fields import build_simple_account_model, simple_account_fields +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields from fields.workflow_run_fields import ( build_workflow_run_for_archived_log_model, build_workflow_run_for_log_model, @@ -25,17 +25,9 @@ workflow_app_log_partial_fields = { def build_workflow_app_log_partial_model(api_or_ns: Namespace): """Build the workflow app log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_app_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowAppLogPartial", copied_fields) @@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = { def build_workflow_archived_log_partial_model(api_or_ns: Namespace): """Build the workflow archived log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_archived_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 56e9cc6a00..8ebc87a670 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -158,7 +158,7 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) ) annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) - return annotations.items, annotations.total + return annotations.items, annotations.total or 0 @classmethod def export_annotation_list_by_app_id(cls, app_id: str): @@ -524,7 +524,7 @@ class AppAnnotationService: annotation_hit_histories = db.paginate( select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False ) - return annotation_hit_histories.items, annotation_hit_histories.total + return annotation_hit_histories.items, annotation_hit_histories.total or 0 @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: From 8b50c0d9205857800d925288c6bdc0d582045d19 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:59:29 +0800 Subject: [PATCH 12/43] chore(deps-dev): bump types-psutil from 7.0.0.20251116 to 7.2.2.20260130 in /api (#31814) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/pyproject.toml | 2 +- api/uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 02d1aea21d..ab1f523267 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -145,7 +145,7 @@ dev = [ "types-openpyxl~=3.1.5", "types-pexpect~=4.9.0", "types-protobuf~=5.29.1", - "types-psutil~=7.0.0", + "types-psutil~=7.2.2", "types-psycopg2~=2.9.21", "types-pygments~=2.19.0", "types-pymysql~=1.1.0", diff --git a/api/uv.lock b/api/uv.lock index ad84b35212..f253976cc1 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1707,7 +1707,7 @@ dev = [ { name = "types-openpyxl", specifier = "~=3.1.5" }, { name = "types-pexpect", specifier = "~=4.9.0" }, { name = "types-protobuf", specifier = "~=5.29.1" }, - { name = "types-psutil", specifier = "~=7.0.0" }, + { name = "types-psutil", specifier = "~=7.2.2" }, { name = "types-psycopg2", specifier = "~=2.9.21" }, { name = "types-pygments", specifier = "~=2.19.0" }, { name = "types-pymysql", specifier = "~=1.1.0" }, @@ -6508,11 +6508,11 @@ wheels = [ [[package]] name = "types-psutil" -version = "7.0.0.20251116" +version = "7.2.2.20260130" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" } +sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" }, + { url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" }, ] [[package]] From b55c0ec4de805e49218040f9b928379172a9d948 Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Tue, 3 Feb 2026 12:26:47 +0800 Subject: [PATCH 13/43] fix: revert "refactor: api/controllers/console/feature.py (test)" (#31850) --- api/controllers/console/feature.py | 94 +++--- .../console/test_fastopenapi_feature.py | 291 ------------------ 2 files changed, 48 insertions(+), 337 deletions(-) delete mode 100644 api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 1e98d622fe..d3811e2d1b 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,58 +1,60 @@ -from pydantic import BaseModel, Field +from flask_restx import Resource, fields from werkzeug.exceptions import Unauthorized -from controllers.fastopenapi import console_router from libs.login import current_account_with_tenant, current_user, login_required -from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel +from services.feature_service import FeatureService +from . import console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required -class FeatureResponse(BaseModel): - features: FeatureModel = Field(description="Feature configuration object") +@console_ns.route("/features") +class FeatureApi(Resource): + @console_ns.doc("get_tenant_features") + @console_ns.doc(description="Get feature configuration for current tenant") + @console_ns.response( + 200, + "Success", + console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), + ) + @setup_required + @login_required + @account_initialization_required + @cloud_utm_record + def get(self): + """Get feature configuration for current tenant""" + _, current_tenant_id = current_account_with_tenant() + + return FeatureService.get_features(current_tenant_id).model_dump() -class SystemFeatureResponse(BaseModel): - features: SystemFeatureModel = Field(description="System feature configuration object") +@console_ns.route("/system-features") +class SystemFeatureApi(Resource): + @console_ns.doc("get_system_features") + @console_ns.doc(description="Get system-wide feature configuration") + @console_ns.response( + 200, + "Success", + console_ns.model( + "SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")} + ), + ) + def get(self): + """Get system-wide feature configuration + NOTE: This endpoint is unauthenticated by design, as it provides system features + data required for dashboard initialization. -@console_router.get( - "/features", - response_model=FeatureResponse, - tags=["console"], -) -@setup_required -@login_required -@account_initialization_required -@cloud_utm_record -def get_tenant_features() -> FeatureResponse: - """Get feature configuration for current tenant.""" - _, current_tenant_id = current_account_with_tenant() + Authentication would create circular dependency (can't login without dashboard loading). - return FeatureResponse(features=FeatureService.get_features(current_tenant_id)) - - -@console_router.get( - "/system-features", - response_model=SystemFeatureResponse, - tags=["console"], -) -def get_system_features() -> SystemFeatureResponse: - """Get system-wide feature configuration - - NOTE: This endpoint is unauthenticated by design, as it provides system features - data required for dashboard initialization. - - Authentication would create circular dependency (can't login without dashboard loading). - - Only non-sensitive configuration data should be returned by this endpoint. - """ - # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` - # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request` - # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will - # raise `Unauthorized` exception if authentication token is not provided. - try: - is_authenticated = current_user.is_authenticated - except Unauthorized: - is_authenticated = False - return SystemFeatureResponse(features=FeatureService.get_system_features(is_authenticated=is_authenticated)) + Only non-sensitive configuration data should be returned by this endpoint. + """ + # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` + # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request` + # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will + # raise `Unauthorized` exception if authentication token is not provided. + try: + is_authenticated = current_user.is_authenticated + except Unauthorized: + is_authenticated = False + return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump() diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py deleted file mode 100644 index 68495dd979..0000000000 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py +++ /dev/null @@ -1,291 +0,0 @@ -import builtins -import contextlib -import importlib -import sys -from unittest.mock import MagicMock, PropertyMock, patch - -import pytest -from flask import Flask -from flask.views import MethodView -from werkzeug.exceptions import Unauthorized - -from extensions import ext_fastopenapi -from extensions.ext_database import db -from services.feature_service import FeatureModel, SystemFeatureModel - - -@pytest.fixture -def app(): - """ - Creates a Flask application instance configured for testing. - """ - app = Flask(__name__) - app.config["TESTING"] = True - app.config["SECRET_KEY"] = "test-secret" - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" - - # Initialize the database with the app - db.init_app(app) - - return app - - -@pytest.fixture(autouse=True) -def fix_method_view_issue(monkeypatch): - """ - Automatic fixture to patch 'builtins.MethodView'. - - Why this is needed: - The official legacy codebase contains a global patch in its initialization logic: - if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView - - Some dependencies (like ext_fastopenapi or older Flask extensions) might implicitly - rely on 'MethodView' being available in the global builtins namespace. - - Refactoring Note: - While patching builtins is generally discouraged due to global side effects, - this fixture reproduces the production environment's state to ensure tests are realistic. - We use 'monkeypatch' to ensure that this change is undone after the test finishes, - keeping other tests isolated. - """ - if not hasattr(builtins, "MethodView"): - # 'raising=False' allows us to set an attribute that doesn't exist yet - monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False) - - -# ------------------------------------------------------------------------------ -# Helper Functions for Fixture Complexity Reduction -# ------------------------------------------------------------------------------ - - -def _create_isolated_router(): - """ - Creates a fresh, isolated router instance to prevent route pollution. - """ - import controllers.fastopenapi - - # Dynamically get the class type (e.g., FlaskRouter) to avoid hardcoding dependencies - RouterClass = type(controllers.fastopenapi.console_router) - return RouterClass() - - -@contextlib.contextmanager -def _patch_auth_and_router(temp_router): - """ - Context manager that applies all necessary patches for: - 1. The console_router (redirecting to our isolated temp_router) - 2. Authentication decorators (disabling them with no-ops) - 3. User/Account loaders (mocking authenticated state) - """ - - def noop(f): - return f - - # We patch the SOURCE of the decorators/functions, not the destination module. - # This ensures that when 'controllers.console.feature' imports them, it gets the mocks. - with ( - patch("controllers.fastopenapi.console_router", temp_router), - patch("extensions.ext_fastopenapi.console_router", temp_router), - patch("controllers.console.wraps.setup_required", side_effect=noop), - patch("libs.login.login_required", side_effect=noop), - patch("controllers.console.wraps.account_initialization_required", side_effect=noop), - patch("controllers.console.wraps.cloud_utm_record", side_effect=noop), - patch("libs.login.current_account_with_tenant", return_value=(MagicMock(), "tenant-id")), - patch("libs.login.current_user", MagicMock(is_authenticated=True)), - ): - # Explicitly reload ext_fastopenapi to ensure it uses the patched console_router - import extensions.ext_fastopenapi - - importlib.reload(extensions.ext_fastopenapi) - - yield - - -def _force_reload_module(target_module: str, alias_module: str): - """ - Forces a reload of the specified module and handles sys.modules aliasing. - - Why reload? - Python decorators (like @route, @login_required) run at IMPORT time. - To apply our patches (mocks/no-ops) to these decorators, we must re-import - the module while the patches are active. - - Why alias? - If 'ext_fastopenapi' imports the controller as 'api.controllers...', but we import - it as 'controllers...', Python treats them as two separate modules. This causes: - 1. Double execution of decorators (registering routes twice -> AssertionError). - 2. Type mismatch errors (Class A from module X is not Class A from module Y). - - This function ensures both names point to the SAME loaded module instance. - """ - # 1. Clean existing entries to force re-import - if target_module in sys.modules: - del sys.modules[target_module] - if alias_module in sys.modules: - del sys.modules[alias_module] - - # 2. Import the module (triggering decorators with active patches) - module = importlib.import_module(target_module) - - # 3. Alias the module in sys.modules to prevent double loading - sys.modules[alias_module] = sys.modules[target_module] - - return module - - -def _cleanup_modules(target_module: str, alias_module: str): - """ - Removes the module and its alias from sys.modules to prevent side effects - on other tests. - """ - if target_module in sys.modules: - del sys.modules[target_module] - if alias_module in sys.modules: - del sys.modules[alias_module] - - -@pytest.fixture -def mock_feature_module_env(): - """ - Sets up a mocked environment for the feature module. - - This fixture orchestrates: - 1. Creating an isolated router. - 2. Patching authentication and global dependencies. - 3. Reloading the controller module to apply patches to decorators. - 4. cleaning up sys.modules afterwards. - """ - target_module = "controllers.console.feature" - alias_module = "api.controllers.console.feature" - - # 1. Prepare isolated router - temp_router = _create_isolated_router() - - # 2. Apply patches - try: - with _patch_auth_and_router(temp_router): - # 3. Reload module to register routes on the temp_router - feature_module = _force_reload_module(target_module, alias_module) - - yield feature_module - - finally: - # 4. Teardown: Clean up sys.modules - _cleanup_modules(target_module, alias_module) - - -# ------------------------------------------------------------------------------ -# Test Cases -# ------------------------------------------------------------------------------ - - -@pytest.mark.parametrize( - ("url", "service_mock_path", "mock_model_instance", "json_key"), - [ - ( - "/console/api/features", - "controllers.console.feature.FeatureService.get_features", - FeatureModel(can_replace_logo=True), - "features", - ), - ( - "/console/api/system-features", - "controllers.console.feature.FeatureService.get_system_features", - SystemFeatureModel(enable_marketplace=True), - "features", - ), - ], -) -def test_console_features_success(app, mock_feature_module_env, url, service_mock_path, mock_model_instance, json_key): - """ - Tests that the feature APIs return a 200 OK status and correct JSON structure. - """ - # Patch the service layer to return our mock model instance - with patch(service_mock_path, return_value=mock_model_instance): - # Initialize the API extension - ext_fastopenapi.init_app(app) - - client = app.test_client() - response = client.get(url) - - # Assertions - assert response.status_code == 200, f"Request failed with status {response.status_code}: {response.text}" - - # Verify the JSON response matches the Pydantic model dump - expected_data = mock_model_instance.model_dump(mode="json") - assert response.get_json() == {json_key: expected_data} - - -@pytest.mark.parametrize( - ("url", "service_mock_path"), - [ - ("/console/api/features", "controllers.console.feature.FeatureService.get_features"), - ("/console/api/system-features", "controllers.console.feature.FeatureService.get_system_features"), - ], -) -def test_console_features_service_error(app, mock_feature_module_env, url, service_mock_path): - """ - Tests how the application handles Service layer errors. - - Note: When an exception occurs in the view, it is typically caught by the framework - (Flask or the OpenAPI wrapper) and converted to a 500 error response. - This test verifies that the application returns a 500 status code. - """ - # Simulate a service failure - with patch(service_mock_path, side_effect=ValueError("Service Failure")): - ext_fastopenapi.init_app(app) - client = app.test_client() - - # When an exception occurs in the view, it is typically caught by the framework - # (Flask or the OpenAPI wrapper) and converted to a 500 error response. - response = client.get(url) - - assert response.status_code == 500 - # Check if the error details are exposed in the response (depends on error handler config) - # We accept either generic 500 or the specific error message - assert "Service Failure" in response.text or "Internal Server Error" in response.text - - -def test_system_features_unauthenticated(app, mock_feature_module_env): - """ - Tests that /console/api/system-features endpoint works without authentication. - - This test verifies the try-except block in get_system_features that handles - unauthenticated requests by passing is_authenticated=False to the service layer. - """ - feature_module = mock_feature_module_env - - # Override the behavior of the current_user mock - # The fixture patched 'libs.login.current_user', so 'controllers.console.feature.current_user' - # refers to that same Mock object. - mock_user = feature_module.current_user - - # Simulate property access raising Unauthorized - # Note: We must reset side_effect if it was set, or set it here. - # The fixture initialized it as MagicMock(is_authenticated=True). - # We want type(mock_user).is_authenticated to raise Unauthorized. - type(mock_user).is_authenticated = PropertyMock(side_effect=Unauthorized) - - # Patch the service layer for this specific test - with patch("controllers.console.feature.FeatureService.get_system_features") as mock_service: - # Setup mock service return value - mock_model = SystemFeatureModel(enable_marketplace=True) - mock_service.return_value = mock_model - - # Initialize app - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.get("/console/api/system-features") - - # Assert - assert response.status_code == 200, f"Request failed: {response.text}" - - # Verify service was called with is_authenticated=False - mock_service.assert_called_once_with(is_authenticated=False) - - # Verify response body - expected_data = mock_model.model_dump(mode="json") - assert response.get_json() == {"features": expected_data} From aa7fe42615b7d0fd4a8fe638e8d57a5959a7f7a8 Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Tue, 3 Feb 2026 13:47:30 +0800 Subject: [PATCH 14/43] test: enhance CommandSelector and GotoAnythingProvider tests (#31743) Co-authored-by: CodingOnStar --- .../app/create-app-modal/index.spec.tsx | 4 +- .../explore/create-app-modal/index.spec.tsx | 32 +- .../goto-anything/command-selector.spec.tsx | 201 ++++++ .../components/empty-state.spec.tsx | 157 +++++ .../goto-anything/components/empty-state.tsx | 105 ++++ .../goto-anything/components/footer.spec.tsx | 273 ++++++++ .../goto-anything/components/footer.tsx | 90 +++ .../goto-anything/components/index.ts | 14 + .../goto-anything/components/result-item.tsx | 38 ++ .../goto-anything/components/result-list.tsx | 49 ++ .../components/search-input.spec.tsx | 206 ++++++ .../goto-anything/components/search-input.tsx | 62 ++ .../components/goto-anything/context.spec.tsx | 77 ++- .../components/goto-anything/hooks/index.ts | 11 + .../hooks/use-goto-anything-modal.spec.ts | 291 +++++++++ .../hooks/use-goto-anything-modal.ts | 59 ++ .../use-goto-anything-navigation.spec.ts | 391 ++++++++++++ .../hooks/use-goto-anything-navigation.ts | 96 +++ .../hooks/use-goto-anything-results.spec.ts | 354 +++++++++++ .../hooks/use-goto-anything-results.ts | 115 ++++ .../hooks/use-goto-anything-search.spec.ts | 301 +++++++++ .../hooks/use-goto-anything-search.ts | 77 +++ .../components/goto-anything/index.spec.tsx | 581 +++++++++++++++-- web/app/components/goto-anything/index.tsx | 585 +++++------------- .../workflow-onboarding-modal/index.spec.tsx | 4 +- web/eslint-suppressions.json | 10 - 26 files changed, 3666 insertions(+), 517 deletions(-) create mode 100644 web/app/components/goto-anything/components/empty-state.spec.tsx create mode 100644 web/app/components/goto-anything/components/empty-state.tsx create mode 100644 web/app/components/goto-anything/components/footer.spec.tsx create mode 100644 web/app/components/goto-anything/components/footer.tsx create mode 100644 web/app/components/goto-anything/components/index.ts create mode 100644 web/app/components/goto-anything/components/result-item.tsx create mode 100644 web/app/components/goto-anything/components/result-list.tsx create mode 100644 web/app/components/goto-anything/components/search-input.spec.tsx create mode 100644 web/app/components/goto-anything/components/search-input.tsx create mode 100644 web/app/components/goto-anything/hooks/index.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-modal.spec.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-modal.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-navigation.spec.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-results.spec.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-results.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-search.spec.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-search.ts diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index cb8f4db67f..d26a581fda 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -124,7 +124,7 @@ describe('CreateAppModal', () => { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({ name: 'My App', @@ -152,7 +152,7 @@ describe('CreateAppModal', () => { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalled()) expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' }) diff --git a/web/app/components/explore/create-app-modal/index.spec.tsx b/web/app/components/explore/create-app-modal/index.spec.tsx index 7ddb5a9082..65ec0e6096 100644 --- a/web/app/components/explore/create-app-modal/index.spec.tsx +++ b/web/app/components/explore/create-app-modal/index.spec.tsx @@ -138,7 +138,7 @@ describe('CreateAppModal', () => { setup({ appName: 'My App', isEditModal: false }) expect(screen.getByText('explore.appCustomize.title:{"name":"My App"}')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeInTheDocument() expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument() }) @@ -146,7 +146,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: true, appMode: AppModeEnum.CHAT, max_active_requests: 5 }) expect(screen.getByText('app.editAppTitle')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeInTheDocument() expect(screen.getByRole('switch')).toBeInTheDocument() expect((screen.getByRole('spinbutton') as HTMLInputElement).value).toBe('5') }) @@ -166,7 +166,7 @@ describe('CreateAppModal', () => { it('should not render modal content when hidden', () => { setup({ show: false }) - expect(screen.queryByRole('button', { name: 'common.operation.create' })).not.toBeInTheDocument() + expect(screen.queryByRole('button', { name: /common\.operation\.create/ })).not.toBeInTheDocument() }) }) @@ -175,13 +175,13 @@ describe('CreateAppModal', () => { it('should disable confirm action when confirmDisabled is true', () => { setup({ confirmDisabled: true }) - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) it('should disable confirm action when appName is empty', () => { setup({ appName: ' ' }) - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) }) @@ -245,7 +245,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: false }) expect(screen.getByText('billing.apps.fullTip2')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) it('should allow saving when apps quota is reached in edit mode', () => { @@ -257,7 +257,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: true }) expect(screen.queryByText('billing.apps.fullTip2')).not.toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeEnabled() + expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeEnabled() }) }) @@ -384,7 +384,7 @@ describe('CreateAppModal', () => { fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' })) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -433,7 +433,7 @@ describe('CreateAppModal', () => { expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument() // Submit and verify the payload uses the original icon (cancel reverts to props) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -471,7 +471,7 @@ describe('CreateAppModal', () => { appIconBackground: '#000000', }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -495,7 +495,7 @@ describe('CreateAppModal', () => { const { onConfirm } = setup({ appDescription: 'Old description' }) fireEvent.change(screen.getByPlaceholderText('app.newApp.appDescriptionPlaceholder'), { target: { value: 'Updated description' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -512,7 +512,7 @@ describe('CreateAppModal', () => { appIconBackground: null, }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -536,7 +536,7 @@ describe('CreateAppModal', () => { fireEvent.click(screen.getByRole('switch')) fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '12' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -551,7 +551,7 @@ describe('CreateAppModal', () => { it('should omit max_active_requests when input is empty', () => { const { onConfirm } = setup({ isEditModal: true, max_active_requests: null }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -564,7 +564,7 @@ describe('CreateAppModal', () => { const { onConfirm } = setup({ isEditModal: true, max_active_requests: null }) fireEvent.change(screen.getByRole('spinbutton'), { target: { value: 'abc' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -576,7 +576,7 @@ describe('CreateAppModal', () => { it('should show toast error and not submit when name becomes empty before debounced submit runs', () => { const { onConfirm, onHide } = setup({ appName: 'My App' }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) fireEvent.change(screen.getByPlaceholderText('app.newApp.appNamePlaceholder'), { target: { value: ' ' } }) act(() => { diff --git a/web/app/components/goto-anything/command-selector.spec.tsx b/web/app/components/goto-anything/command-selector.spec.tsx index 0ee2086058..0712a1afd6 100644 --- a/web/app/components/goto-anything/command-selector.spec.tsx +++ b/web/app/components/goto-anything/command-selector.spec.tsx @@ -81,4 +81,205 @@ describe('CommandSelector', () => { expect(onSelect).toHaveBeenCalledWith('/zen') }) + + it('should show all slash commands when no filter provided', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + // Should show the zen command from mock + expect(screen.getByText('/zen')).toBeInTheDocument() + }) + + it('should exclude slash action when in @ mode', () => { + const actions = { + ...createActions(), + slash: { + key: '/', + shortcut: '/', + title: 'Slash', + search: vi.fn(), + description: '', + } as ActionItem, + } + const onSelect = vi.fn() + + render( + + + , + ) + + // Should show @ commands but not / + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.queryByText('/')).not.toBeInTheDocument() + }) + + it('should show all actions when no filter in @ mode', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.getByText('@plugin')).toBeInTheDocument() + }) + + it('should set default command value when items exist but value does not', () => { + const actions = createActions() + const onSelect = vi.fn() + const onCommandValueChange = vi.fn() + + render( + + + , + ) + + expect(onCommandValueChange).toHaveBeenCalledWith('@app') + }) + + it('should NOT set command value when value already exists in items', () => { + const actions = createActions() + const onSelect = vi.fn() + const onCommandValueChange = vi.fn() + + render( + + + , + ) + + expect(onCommandValueChange).not.toHaveBeenCalled() + }) + + it('should show no matching commands message when filter has no results', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument() + }) + + it('should show no matching commands for slash mode with no results', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() + }) + + it('should render description for @ commands', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.actions.searchPluginsDesc')).toBeInTheDocument() + }) + + it('should render group header for @ mode', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.selectSearchType')).toBeInTheDocument() + }) + + it('should render group header for slash mode', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.groups.commands')).toBeInTheDocument() + }) }) diff --git a/web/app/components/goto-anything/components/empty-state.spec.tsx b/web/app/components/goto-anything/components/empty-state.spec.tsx new file mode 100644 index 0000000000..e1e5e0dc89 --- /dev/null +++ b/web/app/components/goto-anything/components/empty-state.spec.tsx @@ -0,0 +1,157 @@ +import { render, screen } from '@testing-library/react' +import EmptyState from './empty-state' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string, shortcuts?: string }) => { + if (options?.shortcuts !== undefined) + return `${key}:${options.shortcuts}` + return `${options?.ns || 'common'}.${key}` + }, + }), +})) + +describe('EmptyState', () => { + describe('loading variant', () => { + it('should render loading spinner', () => { + render() + + expect(screen.getByText('app.gotoAnything.searching')).toBeInTheDocument() + }) + + it('should have spinner animation class', () => { + const { container } = render() + + const spinner = container.querySelector('.animate-spin') + expect(spinner).toBeInTheDocument() + }) + }) + + describe('error variant', () => { + it('should render error message when error has message', () => { + const error = new Error('Connection failed') + render() + + expect(screen.getByText('app.gotoAnything.searchFailed')).toBeInTheDocument() + expect(screen.getByText('Connection failed')).toBeInTheDocument() + }) + + it('should render generic error when error has no message', () => { + render() + + expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.servicesUnavailableMessage')).toBeInTheDocument() + }) + + it('should render generic error when error is undefined', () => { + render() + + expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument() + }) + + it('should have red error text styling', () => { + const error = new Error('Test error') + const { container } = render() + + const errorText = container.querySelector('.text-red-500') + expect(errorText).toBeInTheDocument() + }) + }) + + describe('default variant', () => { + it('should render search title', () => { + render() + + expect(screen.getByText('app.gotoAnything.searchTitle')).toBeInTheDocument() + }) + + it('should render all hint messages', () => { + render() + + expect(screen.getByText('app.gotoAnything.searchHint')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.commandHint')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.slashHint')).toBeInTheDocument() + }) + }) + + describe('no-results variant', () => { + describe('general search mode', () => { + it('should render generic no results message', () => { + render() + + expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument() + }) + + it('should show specific search hint with shortcuts', () => { + const Actions = { + app: { key: '@app', shortcut: '@app' }, + plugin: { key: '@plugin', shortcut: '@plugin' }, + } as unknown as Record + render() + + expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:@app, @plugin')).toBeInTheDocument() + }) + }) + + describe('app search mode', () => { + it('should render no apps found message', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.noAppsFound')).toBeInTheDocument() + }) + + it('should show try different term hint', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.tryDifferentTerm')).toBeInTheDocument() + }) + }) + + describe('plugin search mode', () => { + it('should render no plugins found message', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.noPluginsFound')).toBeInTheDocument() + }) + }) + + describe('knowledge search mode', () => { + it('should render no knowledge bases found message', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.noKnowledgeBasesFound')).toBeInTheDocument() + }) + }) + + describe('node search mode', () => { + it('should render no workflow nodes found message', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.noWorkflowNodesFound')).toBeInTheDocument() + }) + }) + + describe('unknown search mode', () => { + it('should fallback to generic no results message', () => { + render() + + expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument() + }) + }) + }) + + describe('default props', () => { + it('should use general as default searchMode', () => { + render() + + expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument() + }) + + it('should use empty object as default Actions', () => { + render() + + // Should show empty shortcuts + expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/goto-anything/components/empty-state.tsx b/web/app/components/goto-anything/components/empty-state.tsx new file mode 100644 index 0000000000..a07bc1d45a --- /dev/null +++ b/web/app/components/goto-anything/components/empty-state.tsx @@ -0,0 +1,105 @@ +'use client' + +import type { FC } from 'react' +import type { ActionItem } from '../actions/types' +import { useTranslation } from 'react-i18next' + +export type EmptyStateVariant = 'no-results' | 'error' | 'default' | 'loading' + +export type EmptyStateProps = { + variant: EmptyStateVariant + searchMode?: string + error?: Error | null + Actions?: Record +} + +const EmptyState: FC = ({ + variant, + searchMode = 'general', + error, + Actions = {}, +}) => { + const { t } = useTranslation() + + if (variant === 'loading') { + return ( +
+
+
+ {t('gotoAnything.searching', { ns: 'app' })} +
+
+ ) + } + + if (variant === 'error') { + return ( +
+
+
+ {error?.message + ? t('gotoAnything.searchFailed', { ns: 'app' }) + : t('gotoAnything.searchTemporarilyUnavailable', { ns: 'app' })} +
+
+ {error?.message || t('gotoAnything.servicesUnavailableMessage', { ns: 'app' })} +
+
+
+ ) + } + + if (variant === 'default') { + return ( +
+
+
{t('gotoAnything.searchTitle', { ns: 'app' })}
+
+
{t('gotoAnything.searchHint', { ns: 'app' })}
+
{t('gotoAnything.commandHint', { ns: 'app' })}
+
{t('gotoAnything.slashHint', { ns: 'app' })}
+
+
+
+ ) + } + + // variant === 'no-results' + const isCommandSearch = searchMode !== 'general' + const commandType = isCommandSearch ? searchMode.replace('@', '') : '' + + const getNoResultsMessage = () => { + if (!isCommandSearch) { + return t('gotoAnything.noResults', { ns: 'app' }) + } + + const keyMap = { + app: 'gotoAnything.emptyState.noAppsFound', + plugin: 'gotoAnything.emptyState.noPluginsFound', + knowledge: 'gotoAnything.emptyState.noKnowledgeBasesFound', + node: 'gotoAnything.emptyState.noWorkflowNodesFound', + } as const + + return t(keyMap[commandType as keyof typeof keyMap] || 'gotoAnything.noResults', { ns: 'app' }) + } + + const getHintMessage = () => { + if (isCommandSearch) { + return t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' }) + } + + const shortcuts = Object.values(Actions).map(action => action.shortcut).join(', ') + return t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts }) + } + + return ( +
+
+
{getNoResultsMessage()}
+
{getHintMessage()}
+
+
+ ) +} + +export default EmptyState diff --git a/web/app/components/goto-anything/components/footer.spec.tsx b/web/app/components/goto-anything/components/footer.spec.tsx new file mode 100644 index 0000000000..3dfac5f71c --- /dev/null +++ b/web/app/components/goto-anything/components/footer.spec.tsx @@ -0,0 +1,273 @@ +import { render, screen } from '@testing-library/react' +import Footer from './footer' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string, count?: number, scope?: string }) => { + if (options?.count !== undefined) + return `${key}:${options.count}` + if (options?.scope) + return `${key}:${options.scope}` + return `${options?.ns || 'common'}.${key}` + }, + }), +})) + +describe('Footer', () => { + describe('left content', () => { + describe('when there are results', () => { + it('should show result count', () => { + render( +