diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index d872e8201b..816d0e442f 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings): default="postgresql", ) - @computed_field # type: ignore[misc] + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( @@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings): default=os.cpu_count() or 1, ) - @computed_field # type: ignore[misc] + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: # Parse DB_EXTRAS for 'options' diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 6a5197635e..ef89e66980 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -24,7 +24,7 @@ except ImportError: ) else: warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2) - magic = None # type: ignore + magic = None # type: ignore[assignment] from pydantic import BaseModel diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index c6d98374c1..7bd3b8a56e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): user=user, stream=streaming, ) - # FIXME: Type hinting issue here, ignore it for now, will fix it later - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 01ecf0298f..c64f44a603 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): - response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) else: response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index ffa10cd43c..565905be0d 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -98,7 +98,7 @@ class RateLimit: else: return RateLimitGenerator( rate_limit=self, - generator=generator, # ty: ignore [invalid-argument-type] + generator=generator, request_id=request_id, ) diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 45e3c0006b..26c7e60a4c 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline: if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") elif isinstance(e, InvokeError | ValueError): - err = e # ty: ignore [invalid-assignment] + err = e else: description = getattr(e, "description", None) err = Exception(description if description is not None else str(e)) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index c4be429219..b10838f8c9 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel): if "/" not in key: key = str(ModelProviderID(key)) - return self.configurations.get(key, default) # type: ignore + return self.configurations.get(key, default) class ProviderModelBundle(BaseModel): diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 6a2f27b8ba..2bada85582 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz else: # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly # FIXME: mypy does not support the type of spec.loader - spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore + spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment] if not spec or not spec.loader: raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") if use_lazy_loader: diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 7822ed4268..c430fba0b9 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -49,62 +49,80 @@ class IndexingRunner: self.storage = storage self.model_manager = ModelManager() + def _handle_indexing_error(self, document_id: str, error: Exception) -> None: + """Handle indexing errors by updating document status.""" + logger.exception("consume document failed") + document = db.session.get(DatasetDocument, document_id) + if document: + document.indexing_status = "error" + error_message = getattr(error, "description", str(error)) + document.error = str(error_message) + document.stopped_at = naive_utc_now() + db.session.commit() + def run(self, dataset_documents: list[DatasetDocument]): """Run the indexing process.""" for dataset_document in dataset_documents: + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found, skipping document id: %s", document_id) + continue + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get the process rule stmt = select(DatasetProcessRule).where( - DatasetProcessRule.id == dataset_document.dataset_process_rule_id + DatasetProcessRule.id == requeried_document.dataset_process_rule_id ) processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract - text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) + text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform documents = self._transform( - index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() ) # save segment - self._load_segments(dataset, dataset_document, documents) + self._load_segments(dataset, requeried_document, documents) # load self._load( index_processor=index_processor, dataset=dataset, - dataset_document=dataset_document, + dataset_document=requeried_document, documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except ObjectDeletedError: - logger.warning("Document deleted, document id: %s", dataset_document.id) + logger.warning("Document deleted, document id: %s", document_id) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def run_in_splitting_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is splitting.""" + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found: %s", document_id) + return + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") @@ -112,57 +130,60 @@ class IndexingRunner: # get exist document_segment list and delete document_segments = ( db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) .all() ) for document_segment in document_segments: db.session.delete(document_segment) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: # delete child chunks db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule - stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id) processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract - text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) + text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform documents = self._transform( - index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() ) # save segment - self._load_segments(dataset, dataset_document, documents) + self._load_segments(dataset, requeried_document, documents) # load self._load( - index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents + index_processor=index_processor, + dataset=dataset, + dataset_document=requeried_document, + documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def run_in_indexing_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is indexing.""" + document_id = dataset_document.id try: + # Re-query the document to ensure it's bound to the current session + requeried_document = db.session.get(DatasetDocument, document_id) + if not requeried_document: + logger.warning("Document not found: %s", document_id) + return + # get dataset - dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") @@ -170,7 +191,7 @@ class IndexingRunner: # get exist document_segment list and delete document_segments = ( db.session.query(DocumentSegment) - .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .filter_by(dataset_id=dataset.id, document_id=requeried_document.id) .all() ) @@ -188,7 +209,7 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunks = document_segment.get_child_chunks() if child_chunks: child_documents = [] @@ -206,24 +227,20 @@ class IndexingRunner: document.children = child_documents documents.append(document) # build index - index_type = dataset_document.doc_form + index_type = requeried_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( - index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents + index_processor=index_processor, + dataset=dataset, + dataset_document=requeried_document, + documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") + raise DocumentIsPausedError(f"Document paused, document id: {document_id}") except ProviderTokenNotInitError as e: - dataset_document.indexing_status = "error" - dataset_document.error = str(e.description) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) except Exception as e: - logger.exception("consume document failed") - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - dataset_document.stopped_at = naive_utc_now() - db.session.commit() + self._handle_indexing_error(document_id, e) def indexing_estimate( self, diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 92e6b8ea60..4de4f403ce 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -2,7 +2,7 @@ import logging import os from datetime import datetime, timedelta -from langfuse import Langfuse # type: ignore +from langfuse import Langfuse from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 5095b46432..e9dc58eec8 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -180,7 +180,7 @@ class BasePluginClient: Make a request to the plugin daemon inner API and return the response as a model. """ response = self._request(method, path, headers, data, params, files) - return type_(**response.json()) # type: ignore + return type_(**response.json()) # type: ignore[return-value] def _request_with_plugin_daemon_response( self, diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 460bb75722..c7f5942f5f 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") - self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None + self._tenant_id = tenant_id # Store app context self._app_id = app_id diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 21a0b7eefe..9b8e45b1eb 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") - self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None + self._tenant_id = tenant_id # Store app context self._app_id = app_id diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 854c122331..02fcabab5d 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory: try: repository_class = import_string(class_path) - return repository_class( # type: ignore[no-any-return] + return repository_class( session_factory=session_factory, user=user, app_id=app_id, @@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory: try: repository_class = import_string(class_path) - return repository_class( # type: ignore[no-any-return] + return repository_class( session_factory=session_factory, user=user, app_id=app_id, diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 2e94907f30..a391136a5c 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController): """ returns the tool that the provider can provide """ - return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore + return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) @property def need_credentials(self) -> bool: diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 8bc159bb85..5009f7ac21 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -43,7 +43,7 @@ class TTSTool(BuiltinTool): content_text=tool_parameters.get("text"), # type: ignore user=user_id, tenant_id=self.runtime.tenant_id, - voice=voice, # type: ignore + voice=voice, ) buffer = io.BytesIO() for chunk in tts: diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index 197b062e44..d0a41b940f 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool): yield self.create_text_message(f"{timestamp}") + # TODO: this method's type is messy @staticmethod def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: try: diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index babfa9bcd9..e23ae3b001 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool): datetime_with_tz = input_timezone.localize(local_time) # timezone convert converted_datetime = datetime_with_tz.astimezone(output_timezone) - return converted_datetime.strftime(format=time_format) # type: ignore + return converted_datetime.strftime(time_format) except Exception as e: raise ToolInvokeError(str(e)) diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 0c2870727e..f0e4dba9c3 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -105,7 +105,7 @@ class MCPToolProviderController(ToolProviderController): """ pass - def get_tool(self, tool_name: str) -> MCPTool: # type: ignore + def get_tool(self, tool_name: str) -> MCPTool: """ return tool with given name """ @@ -128,7 +128,7 @@ class MCPToolProviderController(ToolProviderController): sse_read_timeout=self.sse_read_timeout, ) - def get_tools(self) -> list[MCPTool]: # type: ignore + def get_tools(self) -> list[MCPTool]: """ get all tools """ diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 39646b7fc8..90d5a647e9 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -26,7 +26,7 @@ class ToolLabelManager: labels = cls.filter_tool_labels(labels) if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id # ty: ignore [unresolved-attribute] + provider_id = controller.provider_id else: raise ValueError("Unsupported tool type") @@ -51,7 +51,7 @@ class ToolLabelManager: Get tool labels """ if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id # ty: ignore [unresolved-attribute] + provider_id = controller.provider_id elif isinstance(controller, BuiltinToolProviderController): return controller.tool_labels else: @@ -85,7 +85,7 @@ class ToolLabelManager: provider_ids = [] for controller in tool_providers: assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) - provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] + provider_ids.append(controller.provider_id) labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 915a22dd0f..f96510fb45 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - document = db.session.scalar(dataset_document_stmt) # type: ignore + document = db.session.scalar(dataset_document_stmt) if dataset and document: source = RetrievalSourceMetadata( dataset_id=dataset.id, dataset_name=dataset.name, - document_id=document.id, # type: ignore - document_name=document.name, # type: ignore - data_source_type=document.data_source_type, # type: ignore + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, segment_id=segment.id, retriever_from=self.retriever_from, score=record.score or 0.0, - doc_metadata=document.doc_metadata, # type: ignore + doc_metadata=document.doc_metadata, ) if self.retriever_from == "dev": diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 52c16c34a0..ef6913d0bd 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -6,8 +6,8 @@ from typing import Any, cast from urllib.parse import unquote import chardet -import cloudscraper # type: ignore -from readabilipy import simple_json_from_html_string # type: ignore +import cloudscraper +from readabilipy import simple_json_from_html_string from core.helper import ssrf_proxy from core.rag.extractor import extract_processor @@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str: response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() - scraper.perform_request = ssrf_proxy.make_request # type: ignore - response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore + scraper.perform_request = ssrf_proxy.make_request + response = scraper.get(url, headers=headers, timeout=(120, 300)) if response.status_code != 200: return f"URL returned status code {response.status_code}." diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index e9b5dab7d3..071154ee71 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -3,7 +3,7 @@ from functools import lru_cache from pathlib import Path from typing import Any -import yaml # type: ignore +import yaml from yaml import YAMLError logger = logging.getLogger(__name__) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 4d9c8895fc..e514c8c57b 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -99,7 +99,7 @@ class WorkflowToolProviderController(ToolProviderController): variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) def fetch_workflow_variable(variable_name: str) -> VariableEntity | None: - return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore + return next(filter(lambda x: x.variable == variable_name, variables), None) user = db_provider.user diff --git a/api/core/variables/segment_group.py b/api/core/variables/segment_group.py index 0a41b64228..b363255b2c 100644 --- a/api/core/variables/segment_group.py +++ b/api/core/variables/segment_group.py @@ -4,7 +4,7 @@ from .types import SegmentType class SegmentGroup(Segment): value_type: SegmentType = SegmentType.GROUP - value: list[Segment] = None # type: ignore + value: list[Segment] @property def text(self): diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 6c9e6d726e..406b4e6f93 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -19,7 +19,7 @@ class Segment(BaseModel): model_config = ConfigDict(frozen=True) value_type: SegmentType - value: Any = None + value: Any @field_validator("value_type") @classmethod @@ -74,12 +74,12 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING - value: str = None # type: ignore + value: str class FloatSegment(Segment): value_type: SegmentType = SegmentType.FLOAT - value: float = None # type: ignore + value: float # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. # @@ -98,12 +98,12 @@ class FloatSegment(Segment): class IntegerSegment(Segment): value_type: SegmentType = SegmentType.INTEGER - value: int = None # type: ignore + value: int class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] = None # type: ignore + value: Mapping[str, Any] @property def text(self) -> str: @@ -136,7 +136,7 @@ class ArraySegment(Segment): class FileSegment(Segment): value_type: SegmentType = SegmentType.FILE - value: File = None # type: ignore + value: File @property def markdown(self) -> str: @@ -153,17 +153,17 @@ class FileSegment(Segment): class BooleanSegment(Segment): value_type: SegmentType = SegmentType.BOOLEAN - value: bool = None # type: ignore + value: bool class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] = None # type: ignore + value: Sequence[Any] class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] = None # type: ignore + value: Sequence[str] @property def text(self) -> str: @@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment): class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] = None # type: ignore + value: Sequence[float | int] class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] = None # type: ignore + value: Sequence[Mapping[str, Any]] class ArrayFileSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] = None # type: ignore + value: Sequence[File] @property def markdown(self) -> str: @@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment): class ArrayBooleanSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] = None # type: ignore + value: Sequence[bool] def get_segment_discriminator(v: Any) -> SegmentType | None: diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index be70e467a0..185f0ad620 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -1,3 +1,5 @@ +from ..runtime.graph_runtime_state import GraphRuntimeState +from ..runtime.variable_pool import VariablePool from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution @@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", + "GraphRuntimeState", + "VariablePool", "WorkflowExecution", "WorkflowNodeExecution", ] diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 20b5193875..d04724425c 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -3,11 +3,12 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Protocol, cast, final -from core.workflow.enums import NodeExecutionType, NodeState, NodeType +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType from core.workflow.nodes.base.node import Node from libs.typing import is_str, is_str_dict from .edge import Edge +from .validation import get_graph_validator logger = logging.getLogger(__name__) @@ -201,6 +202,17 @@ class Graph: return GraphBuilder(graph_cls=cls) + @classmethod + def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: + """ + Promote nodes configured with FAIL_BRANCH error strategy to branch execution type. + + :param nodes: mapping of node ID to node instance + """ + for node in nodes.values(): + if node.error_strategy == ErrorStrategy.FAIL_BRANCH: + node.execution_type = NodeExecutionType.BRANCH + @classmethod def _mark_inactive_root_branches( cls, @@ -307,6 +319,9 @@ class Graph: # Create node instances nodes = cls._create_node_instances(node_configs_map, node_factory) + # Promote fail-branch nodes to branch execution type at graph level + cls._promote_fail_branch_nodes(nodes) + # Get root node instance root_node = nodes[root_node_id] @@ -314,7 +329,7 @@ class Graph: cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) # Create and return the graph - return cls( + graph = cls( nodes=nodes, edges=edges, in_edges=in_edges, @@ -322,6 +337,11 @@ class Graph: root_node=root_node, ) + # Validate the graph structure using built-in validators + get_graph_validator().validate(graph) + + return graph + @property def node_ids(self) -> list[str]: """ diff --git a/api/core/workflow/graph/validation.py b/api/core/workflow/graph/validation.py new file mode 100644 index 0000000000..87aa7db2e4 --- /dev/null +++ b/api/core/workflow/graph/validation.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Protocol + +from core.workflow.enums import NodeExecutionType, NodeType + +if TYPE_CHECKING: + from .graph import Graph + + +@dataclass(frozen=True, slots=True) +class GraphValidationIssue: + """Immutable value object describing a single validation issue.""" + + code: str + message: str + node_id: str | None = None + + +class GraphValidationError(ValueError): + """Raised when graph validation fails.""" + + def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: + if not issues: + raise ValueError("GraphValidationError requires at least one issue.") + self.issues: tuple[GraphValidationIssue, ...] = tuple(issues) + message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues) + super().__init__(message) + + +class GraphValidationRule(Protocol): + """Protocol that individual validation rules must satisfy.""" + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + """Validate the provided graph and return any discovered issues.""" + ... + + +@dataclass(frozen=True, slots=True) +class _EdgeEndpointValidator: + """Ensures all edges reference existing nodes.""" + + missing_node_code: str = "MISSING_NODE" + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + issues: list[GraphValidationIssue] = [] + for edge in graph.edges.values(): + if edge.tail not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.missing_node_code, + message=f"Edge {edge.id} references unknown source node '{edge.tail}'.", + node_id=edge.tail, + ) + ) + if edge.head not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.missing_node_code, + message=f"Edge {edge.id} references unknown target node '{edge.head}'.", + node_id=edge.head, + ) + ) + return issues + + +@dataclass(frozen=True, slots=True) +class _RootNodeValidator: + """Validates root node invariants.""" + + invalid_root_code: str = "INVALID_ROOT" + container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START) + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + root_node = graph.root_node + issues: list[GraphValidationIssue] = [] + if root_node.id not in graph.nodes: + issues.append( + GraphValidationIssue( + code=self.invalid_root_code, + message=f"Root node '{root_node.id}' is missing from the node registry.", + node_id=root_node.id, + ) + ) + return issues + + node_type = getattr(root_node, "node_type", None) + if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: + issues.append( + GraphValidationIssue( + code=self.invalid_root_code, + message=f"Root node '{root_node.id}' must declare execution type 'root'.", + node_id=root_node.id, + ) + ) + return issues + + +@dataclass(frozen=True, slots=True) +class GraphValidator: + """Coordinates execution of graph validation rules.""" + + rules: tuple[GraphValidationRule, ...] + + def validate(self, graph: Graph) -> None: + """Validate the graph against all configured rules.""" + issues: list[GraphValidationIssue] = [] + for rule in self.rules: + issues.extend(rule.validate(graph)) + + if issues: + raise GraphValidationError(issues) + + +_DEFAULT_RULES: tuple[GraphValidationRule, ...] = ( + _EdgeEndpointValidator(), + _RootNodeValidator(), +) + + +def get_graph_validator() -> GraphValidator: + """Construct the validator composed of default rules.""" + return GraphValidator(_DEFAULT_RULES) diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 5aef9d79cf..94b0d1d8bc 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,5 +1,6 @@ import json from abc import ABC +from builtins import type as type_ from collections.abc import Sequence from enum import StrEnum from typing import Any, Union @@ -58,10 +59,9 @@ class DefaultValue(BaseModel): raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") @staticmethod - def _validate_array(value: Any, element_type: DefaultValueType) -> bool: + def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: """Unified array type validation""" - # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) @staticmethod def _convert_number(value: str) -> float: diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index ae1061d72c..cd5f50aaab 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -10,10 +10,10 @@ from typing import Any import chardet import docx import pandas as pd -import pypandoc # type: ignore -import pypdfium2 # type: ignore -import webvtt # type: ignore -import yaml # type: ignore +import pypandoc +import pypdfium2 +import webvtt +import yaml from docx.document import Document from docx.oxml.table import CT_Tbl from docx.oxml.text.paragraph import CT_P diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2dc3cb9320..ba5134f9e6 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -141,7 +141,7 @@ class KnowledgeRetrievalNode(Node): def version(cls): return "1" - def _run(self) -> NodeRunResult: # type: ignore + def _run(self) -> NodeRunResult: # extract variables variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) if not isinstance(variable, StringSegment): @@ -443,7 +443,7 @@ class KnowledgeRetrievalNode(Node): metadata_condition = MetadataCondition( logical_operator=node_data.metadata_filtering_conditions.logical_operator if node_data.metadata_filtering_conditions - else "or", # type: ignore + else "or", conditions=conditions, ) elif node_data.metadata_filtering_mode == "manual": @@ -457,10 +457,10 @@ class KnowledgeRetrievalNode(Node): expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: # type: ignore - expected_value = expected_value.value # type: ignore - elif expected_value.value_type == "string": # type: ignore - expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore + if expected_value.value_type in {"number", "integer", "float"}: + expected_value = expected_value.value + elif expected_value.value_type == "string": + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() else: raise ValueError("Invalid expected metadata value type") conditions.append( @@ -487,7 +487,7 @@ class KnowledgeRetrievalNode(Node): if ( node_data.metadata_filtering_conditions and node_data.metadata_filtering_conditions.logical_operator == "and" - ): # type: ignore + ): document_query = document_query.where(and_(*filters)) else: document_query = document_query.where(or_(*filters)) diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index 87d1b8c435..84f63d57eb 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final from typing_extensions import override -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.enums import NodeType from core.workflow.graph import NodeFactory from core.workflow.nodes.base.node import Node from libs.typing import is_str, is_str_dict @@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory): raise ValueError(f"Node {node_id} missing data information") node_instance.init_node_data(node_data) - # If node has fail branch, change execution type to branch - if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH: - node_instance.execution_type = NodeExecutionType.BRANCH - return node_instance diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2b65cc30b6..e250650fef 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -747,7 +747,7 @@ class ParameterExtractorNode(Node): if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction), + text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), ) user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index b74be8f206..1b29be4418 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside -{{instructions}} +{instructions} """ diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 5fd6e894f1..d41a20dfd7 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -260,7 +260,7 @@ class VariablePool(BaseModel): # This ensures that we can keep the id of the system variables intact. if self._has(selector): continue - self.add(selector, value) # type: ignore + self.add(selector, value) @classmethod def empty(cls) -> "VariablePool": diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 52fef4929f..82f0542b35 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -1,7 +1,12 @@ from configs import dify_config -from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN +from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT from dify_app import DifyApp +BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT) +SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization") +AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN) +FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) + def init_app(app: DifyApp): # register blueprint routers @@ -17,7 +22,7 @@ def init_app(app: DifyApp): CORS( service_api_bp, - allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE], + allow_headers=list(SERVICE_API_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(service_api_bp) @@ -26,7 +31,7 @@ def init_app(app: DifyApp): web_bp, resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN], + allow_headers=list(AUTHENTICATED_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -36,7 +41,7 @@ def init_app(app: DifyApp): console_app_bp, resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN], + allow_headers=list(AUTHENTICATED_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -44,7 +49,7 @@ def init_app(app: DifyApp): CORS( files_bp, - allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN], + allow_headers=list(FILES_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(files_bp) diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 26ff6427be..9c3a663af4 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -7,7 +7,7 @@ def is_enabled() -> bool: def init_app(app: DifyApp): - from flask_compress import Compress # type: ignore + from flask_compress import Compress compress = Compress() compress.init_app(app) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index e7816a2e88..ed4fe332c1 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,6 +1,6 @@ import json -import flask_login # type: ignore +import flask_login from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import NotFound, Unauthorized diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index 5f862181fa..6d8f35c30d 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -2,7 +2,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): - import flask_migrate # type: ignore + import flask_migrate from extensions.ext_database import db diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index cb6e4849a9..20ac2503a2 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -103,7 +103,7 @@ def init_app(app: DifyApp): def shutdown_tracer(): provider = trace.get_tracer_provider() if hasattr(provider, "force_flush"): - provider.force_flush() # ty: ignore [call-non-callable] + provider.force_flush() class ExceptionLoggingHandler(logging.Handler): """Custom logging handler that creates spans for logging.exception() calls""" diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index c085aed986..fe6685f633 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -6,4 +6,4 @@ def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix - app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore + app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign] diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 5ed7840211..c3aa8edf80 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,7 +5,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk - from langfuse import parse_error # type: ignore + from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 5da4737138..2283581f62 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,7 +1,7 @@ import posixpath from collections.abc import Generator -import oss2 as aliyun_s3 # type: ignore +import oss2 as aliyun_s3 from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index b94efa08be..0bb4648c0a 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -2,9 +2,9 @@ import base64 import hashlib from collections.abc import Generator -from baidubce.auth.bce_credentials import BceCredentials # type: ignore -from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore -from baidubce.services.bos.bos_client import BosClient # type: ignore +from baidubce.auth.bce_credentials import BceCredentials +from baidubce.bce_client_configuration import BceClientConfiguration +from baidubce.services.bos.bos_client import BosClient from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 06c528ca41..1cabc57e74 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -11,7 +11,7 @@ from collections.abc import Generator from io import BytesIO from pathlib import Path -import clickzetta # type: ignore[import] +import clickzetta from pydantic import BaseModel, model_validator from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 6dcf800abb..9d4ca689d8 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -34,7 +34,7 @@ class VolumePermissionManager: # Support two initialization methods: connection object or configuration dictionary if isinstance(connection_or_config, dict): # Create connection from configuration dictionary - import clickzetta # type: ignore[import-untyped] + import clickzetta config = connection_or_config self._connection = clickzetta.connect( diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 7f59252f2f..d352996518 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,7 +3,7 @@ import io import json from collections.abc import Generator -from google.cloud import storage as google_cloud_storage # type: ignore +from google.cloud import storage as google_cloud_storage from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 3e75ecb7a9..74fed26f65 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from obs import ObsClient # type: ignore +from obs import ObsClient from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index acc00cbd6b..c032803045 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,7 +1,7 @@ from collections.abc import Generator -import boto3 # type: ignore -from botocore.exceptions import ClientError # type: ignore +import boto3 +from botocore.exceptions import ClientError from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 9cdd3e67f7..ea5d982efc 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from qcloud_cos import CosConfig, CosS3Client # type: ignore +from qcloud_cos import CosConfig, CosS3Client from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 8ed8e4c170..a44959221f 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -import tos # type: ignore +import tos from configs import dify_config from extensions.storage.base_storage import BaseStorage diff --git a/api/libs/external_api.py b/api/libs/external_api.py index f3ebcc4306..1a4fde960c 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -146,6 +146,6 @@ class ExternalApi(Api): kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False # manual separate call on construction and init_app to ensure configs in kwargs effective - super().__init__(app=None, *args, **kwargs) # type: ignore + super().__init__(app=None, *args, **kwargs) self.init_app(app, **kwargs) register_external_error_handlers(self) diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index fc38d51005..23eb8dca05 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -23,7 +23,7 @@ from hashlib import sha1 import Crypto.Hash.SHA1 import Crypto.Util.number -import gmpy2 # type: ignore +import gmpy2 from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes @@ -136,7 +136,7 @@ class PKCS1OAepCipher: # Step 3a (OS2IP) em_int = bytes_to_long(em) # Step 3b (RSAEP) - m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute] + m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # Step 3c (I2OSP) c = long_to_bytes(m_int, k) return c @@ -169,7 +169,7 @@ class PKCS1OAepCipher: ct_int = bytes_to_long(ciphertext) # Step 2b (RSADP) # m_int = self._key._decrypt(ct_int) - m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute] + m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # Complete step 2c (I2OSP) em = long_to_bytes(m_int, k) # Step 3a @@ -191,12 +191,12 @@ class PKCS1OAepCipher: # Step 3g one_pos = hLen + db[hLen:].find(b"\x01") lHash1 = db[:hLen] - invalid = bord(y) | int(one_pos < hLen) # type: ignore + invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type] hash_compare = strxor(lHash1, lHash) for x in hash_compare: - invalid |= bord(x) # type: ignore + invalid |= bord(x) # type: ignore[arg-type] for x in db[hLen:one_pos]: - invalid |= bord(x) # type: ignore + invalid |= bord(x) # type: ignore[arg-type] if invalid != 0: raise ValueError("Incorrect decryption.") # Step 4 diff --git a/api/libs/login.py b/api/libs/login.py index 5ed4bfae8f..4b8ee2d1f8 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -3,7 +3,7 @@ from functools import wraps from typing import Any from flask import current_app, g, has_request_context, request -from flask_login.config import EXEMPT_METHODS # type: ignore +from flask_login.config import EXEMPT_METHODS from werkzeug.local import LocalProxy from configs import dify_config @@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None: if "_login_user" not in g: current_app.login_manager._load_user() # type: ignore - return g._login_user # type: ignore + return g._login_user return None diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index a270fa70fa..c047c54d06 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -1,8 +1,8 @@ import logging -import sendgrid # type: ignore +import sendgrid from python_http_client.exceptions import ForbiddenError, UnauthorizedError -from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore +from sendgrid.helpers.mail import Content, Email, Mail, To logger = logging.getLogger(__name__) diff --git a/api/models/account.py b/api/models/account.py index 86cd9e41b5..400a2c6362 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Any, Optional import sqlalchemy as sa -from flask_login import UserMixin # type: ignore[import-untyped] +from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated diff --git a/api/models/model.py b/api/models/model.py index af22ab9538..8a8574e2fe 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast import sqlalchemy as sa from flask import request -from flask_login import UserMixin # type: ignore[import-untyped] +from flask_login import UserMixin from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index bf4ec2314e..6a689b96df 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -16,7 +16,25 @@ "opentelemetry.instrumentation.requests", "opentelemetry.instrumentation.sqlalchemy", "opentelemetry.instrumentation.redis", - "opentelemetry.instrumentation.httpx" + "langfuse", + "cloudscraper", + "readabilipy", + "pypandoc", + "pypdfium2", + "webvtt", + "flask_compress", + "oss2", + "baidubce.auth.bce_credentials", + "baidubce.bce_client_configuration", + "baidubce.services.bos.bos_client", + "clickzetta", + "google.cloud", + "obs", + "qcloud_cos", + "tos", + "gmpy2", + "sendgrid", + "sendgrid.helpers.mail" ], "reportUnknownMemberType": "hint", "reportUnknownParameterType": "hint", @@ -28,7 +46,7 @@ "reportUnnecessaryComparison": "hint", "reportUnnecessaryIsInstance": "hint", "reportUntypedFunctionDecorator": "hint", - + "reportUnnecessaryTypeIgnoreComment": "hint", "reportAttributeAccessIssue": "hint", "pythonVersion": "3.11", "pythonPlatform": "All" diff --git a/api/repositories/factory.py b/api/repositories/factory.py index 0be9c8908c..96f9f886a4 100644 --- a/api/repositories/factory.py +++ b/api/repositories/factory.py @@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): try: repository_class = import_string(class_path) - return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + return repository_class(session_maker=session_maker) except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" @@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): try: repository_class = import_string(class_path) - return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + return repository_class(session_maker=session_maker) except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index e2915ebfbb..edb18a845a 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -7,7 +7,7 @@ from enum import StrEnum from urllib.parse import urlparse from uuid import uuid4 -import yaml # type: ignore +import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from packaging import version @@ -563,7 +563,7 @@ class AppDslService: else: cls._append_model_config_export_data(export_data, app_model) - return yaml.dump(export_data, allow_unicode=True) # type: ignore + return yaml.dump(export_data, allow_unicode=True) @classmethod def _append_workflow_export_data( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f4047da6b8..c97d419545 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -241,9 +241,9 @@ class DatasetService: dataset.created_by = account.id dataset.updated_by = account.id dataset.tenant_id = tenant_id - dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore - dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore - dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore + dataset.embedding_model_provider = embedding_model.provider if embedding_model else None + dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider db.session.add(dataset) @@ -1416,6 +1416,8 @@ class DocumentService: # check document limit assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None + assert knowledge_config.data_source + assert knowledge_config.data_source.info_list.file_info_list features = FeatureService.get_features(current_user.current_tenant_id) @@ -1424,15 +1426,16 @@ class DocumentService: count = 0 if knowledge_config.data_source: if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids count = len(upload_file_list) elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list - for notion_info in notion_info_list: # type: ignore + notion_info_list = knowledge_config.data_source.info_list.notion_info_list or [] + for notion_info in notion_info_list: count = count + len(notion_info.pages) elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": website_info = knowledge_config.data_source.info_list.website_info_list - count = len(website_info.urls) # type: ignore + assert website_info + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if features.billing.subscription.plan == "sandbox" and count > 1: @@ -1444,7 +1447,7 @@ class DocumentService: # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: - dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type if not dataset.indexing_technique: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: @@ -1481,7 +1484,7 @@ class DocumentService: knowledge_config.retrieval_model.model_dump() if knowledge_config.retrieval_model else default_retrieval_model - ) # type: ignore + ) documents = [] if knowledge_config.original_document_id: @@ -1523,11 +1526,12 @@ class DocumentService: db.session.flush() lock_name = f"add_document_lock_dataset_id_{dataset.id}" with redis_client.lock(lock_name, timeout=600): + assert dataset_process_rule position = DocumentService.get_documents_position(dataset.id) document_ids = [] duplicate_document_ids = [] - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -1540,7 +1544,7 @@ class DocumentService: raise FileNotExistsError() file_name = file.name - data_source_info = { + data_source_info: dict[str, str | bool] = { "upload_file_id": file_id, } # check duplicate @@ -1557,7 +1561,7 @@ class DocumentService: .first() ) if document: - document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from document.doc_form = knowledge_config.doc_form @@ -1571,8 +1575,8 @@ class DocumentService: continue document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -1587,7 +1591,7 @@ class DocumentService: document_ids.append(document.id) documents.append(document) position += 1 - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore if not notion_info_list: raise ValueError("No notion info list found.") @@ -1616,15 +1620,15 @@ class DocumentService: "credential_id": notion_info.credential_id, "notion_workspace_id": workspace_id, "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore "type": page.type, } # Truncate page name to 255 characters to prevent DB field length errors truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -1644,8 +1648,8 @@ class DocumentService: # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list if not website_info: raise ValueError("No website info list found.") urls = website_info.urls @@ -1663,8 +1667,8 @@ class DocumentService: document_name = url document = DocumentService.build_document( dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, data_source_info, @@ -2071,7 +2075,7 @@ class DocumentService: # update document data source if document_data.data_source: file_name = "" - data_source_info = {} + data_source_info: dict[str, str | bool] = {} if document_data.data_source.info_list.data_source_type == "upload_file": if not document_data.data_source.info_list.file_info_list: raise ValueError("No file info list found.") @@ -2128,7 +2132,7 @@ class DocumentService: "url": url, "provider": website_info.provider, "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, # type: ignore + "only_main_content": website_info.only_main_content, "mode": "crawl", } document.data_source_type = document_data.data_source.info_list.data_source_type @@ -2154,7 +2158,7 @@ class DocumentService: db.session.query(DocumentSegment).filter_by(document_id=document.id).update( {DocumentSegment.status: "re_segment"} - ) # type: ignore + ) db.session.commit() # trigger async task document_indexing_update_task.delay(document.dataset_id, document.id) @@ -2164,25 +2168,26 @@ class DocumentService: def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None + assert knowledge_config.data_source features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: count = 0 - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + if knowledge_config.data_source.info_list.data_source_type == "upload_file": upload_file_list = ( - knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - if knowledge_config.data_source.info_list.file_info_list # type: ignore + knowledge_config.data_source.info_list.file_info_list.file_ids + if knowledge_config.data_source.info_list.file_info_list else [] ) count = len(upload_file_list) - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore - notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list if notion_info_list: for notion_info in notion_info_list: count = count + len(notion_info.pages) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list if website_info: count = len(website_info.urls) if features.billing.subscription.plan == "sandbox" and count > 1: @@ -2196,9 +2201,11 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None if knowledge_config.indexing_technique == "high_quality": + assert knowledge_config.embedding_model_provider + assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - knowledge_config.embedding_model_provider, # type: ignore - knowledge_config.embedding_model, # type: ignore + knowledge_config.embedding_model_provider, + knowledge_config.embedding_model, ) dataset_collection_binding_id = dataset_collection_binding.id if knowledge_config.retrieval_model: @@ -2215,7 +2222,7 @@ class DocumentService: dataset = Dataset( tenant_id=tenant_id, name="", - data_source_type=knowledge_config.data_source.info_list.data_source_type, # type: ignore + data_source_type=knowledge_config.data_source.info_list.data_source_type, indexing_technique=knowledge_config.indexing_technique, created_by=account.id, embedding_model=knowledge_config.embedding_model, @@ -2224,7 +2231,7 @@ class DocumentService: retrieval_model=retrieval_model.model_dump() if retrieval_model else None, ) - db.session.add(dataset) # type: ignore + db.session.add(dataset) db.session.flush() documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 7fa82c6d22..337181728c 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -88,7 +88,7 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(query, all_documents) # type: ignore + return cls.compact_retrieve_response(query, all_documents) @classmethod def external_retrieve( diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 8df1a6ba14..02fe1d19bc 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,4 +1,4 @@ -import boto3 # type: ignore +import boto3 from configs import dify_config diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 5f280c9e57..b369994d2d 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -89,7 +89,7 @@ class MetadataService: document.doc_metadata = doc_metadata db.session.add(document) db.session.commit() - return metadata # type: ignore + return metadata except Exception: logger.exception("Update metadata name failed") finally: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 2901a0d273..50ddbbf681 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -137,7 +137,7 @@ class ModelProviderService: :return: """ provider_configuration = self._get_provider_configuration(tenant_id, provider) - return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore + return provider_configuration.get_provider_credential(credential_id=credential_id) def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict): """ @@ -225,7 +225,7 @@ class ModelProviderService: :return: """ provider_configuration = self._get_provider_configuration(tenant_id, provider) - return provider_configuration.get_custom_model_credential( # type: ignore + return provider_configuration.get_custom_model_credential( model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index dec92a6faa..df5fa3e233 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -146,7 +146,7 @@ class PluginMigration: futures.append( thread_pool.submit( process_tenant, - current_app._get_current_object(), # type: ignore[attr-defined] + current_app._get_current_object(), # type: ignore tenant_id, ) ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index b5dcec17d0..0628c8f22e 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -544,8 +544,8 @@ class BuiltinToolManageService: try: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, data=provider_controller, name_func=lambda x: x.entity.identity.name, ): diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 54133d3801..92c33c1a49 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -308,7 +308,7 @@ class MCPToolManageService: provider_controller = MCPToolProviderController.from_db(mcp_provider) tool_configuration = ProviderConfigEncrypter( tenant_id=mcp_provider.tenant_id, - config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] + config=list(provider_controller.get_credentials_schema()), provider_config_cache=NoOpProviderCredentialCache(), ) credentials = tool_configuration.encrypt(credentials) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index b528728364..bd95af2614 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -102,7 +102,7 @@ def batch_create_segment_to_index_task( for segment, tokens in zip(content, tokens_list): content = segment["content"] doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) # type: ignore + segment_hash = helper.generate_text_hash(content) max_position = ( db.session.query(func.max(DocumentSegment.position)) .where(DocumentSegment.document_id == dataset_document.id) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 8a43d03a43..3984078ee9 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -5,11 +5,11 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from pymochow import MochowClient # type: ignore -from pymochow.model.database import Database # type: ignore -from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore -from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore -from pymochow.model.table import Table # type: ignore +from pymochow import MochowClient +from pymochow.model.database import Database +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState +from pymochow.model.schema import HNSWParams, VectorIndex +from pymochow.model.table import Table class AttrDict(UserDict): diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 5130fcfe17..8f87d6a073 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -3,15 +3,15 @@ from typing import Any, Union import pytest from _pytest.monkeypatch import MonkeyPatch -from tcvectordb import RPCVectorDBClient # type: ignore +from tcvectordb import RPCVectorDBClient from tcvectordb.model import enum from tcvectordb.model.collection import FilterIndexConfig -from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore -from tcvectordb.model.enum import ReadConsistency # type: ignore -from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore +from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank +from tcvectordb.model.enum import ReadConsistency +from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex from tcvectordb.rpc.model.collection import RPCCollection from tcvectordb.rpc.model.database import RPCDatabase -from xinference_client.types import Embedding # type: ignore +from xinference_client.types import Embedding class MockTcvectordbClass: diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index f351df8d5b..289c515b85 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from volcengine.viking_db import ( # type: ignore +from volcengine.viking_db import ( Collection, Data, DistanceType, diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 5895f63f94..8423f1ab02 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test with None input""" # The method signature expects Union[dict, list, Segment], but implementation handles None # We'll test the actual behavior by passing an empty dict instead - result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore + result = WorkflowResponseConverter._fetch_files_from_variable_value(None) assert result == [] def test_fetch_files_from_variable_value_with_empty_dict(self): diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 895ebdd751..fe9f0935d5 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -235,7 +235,7 @@ class TestIndividualHandlers: # Type assertion needed due to union type text_content = result.content[0] assert hasattr(text_content, "text") - assert text_content.text == "test answer" # type: ignore[attr-defined] + assert text_content.text == "test answer" def test_handle_call_tool_no_end_user(self): """Test call tool handler without end user""" diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py new file mode 100644 index 0000000000..b55d4998c4 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import time +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.graph import Graph +from core.workflow.graph.validation import GraphValidationError +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.node import Node +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom + + +class _TestNode(Node): + node_type = NodeType.ANSWER + execution_type = NodeExecutionType.EXECUTABLE + + @classmethod + def version(cls) -> str: + return "test" + + def __init__( + self, + *, + id: str, + config: Mapping[str, object], + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + data = config.get("data", {}) + if isinstance(data, Mapping): + execution_type = data.get("execution_type") + if isinstance(execution_type, str): + self.execution_type = NodeExecutionType(execution_type) + self._base_node_data = BaseNodeData(title=str(data.get("title", self.id))) + self.data: dict[str, object] = {} + + def init_node_data(self, data: Mapping[str, object]) -> None: + title = str(data.get("title", self.id)) + desc = data.get("description") + error_strategy_value = data.get("error_strategy") + error_strategy: ErrorStrategy | None = None + if isinstance(error_strategy_value, ErrorStrategy): + error_strategy = error_strategy_value + elif isinstance(error_strategy_value, str): + error_strategy = ErrorStrategy(error_strategy_value) + self._base_node_data = BaseNodeData( + title=title, + desc=str(desc) if desc is not None else None, + error_strategy=error_strategy, + ) + self.data = dict(data) + + def _run(self): + raise NotImplementedError + + def _get_error_strategy(self) -> ErrorStrategy | None: + return self._base_node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._base_node_data.retry_config + + def _get_title(self) -> str: + return self._base_node_data.title + + def _get_description(self) -> str | None: + return self._base_node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._base_node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._base_node_data + + +@dataclass(slots=True) +class _SimpleNodeFactory: + graph_init_params: GraphInitParams + graph_runtime_state: GraphRuntimeState + + def create_node(self, node_config: Mapping[str, object]) -> _TestNode: + node_id = str(node_config["id"]) + node = _TestNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + ) + node.init_node_data(node_config.get("data", {})) + return node + + +@pytest.fixture +def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: + graph_config: dict[str, object] = {"edges": [], "nodes": []} + init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) + return factory, graph_config + + +def test_graph_initialization_runs_default_validators( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +): + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}, + ] + graph_config["edges"] = [ + {"source": "start", "target": "answer", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert graph.root_node.id == "start" + assert "answer" in graph.nodes + + +def test_graph_validation_fails_for_unknown_edge_targets( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + ] + graph_config["edges"] = [ + {"source": "start", "target": "missing", "sourceHandle": "success"}, + ] + + with pytest.raises(GraphValidationError) as exc: + Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) + + +def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + { + "id": "branch", + "data": { + "type": NodeType.IF_ELSE, + "title": "Branch", + "error_strategy": ErrorStrategy.FAIL_BRANCH, + }, + }, + ] + graph_config["edges"] = [ + {"source": "start", "target": "branch", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index b9947d4693..b359284d00 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -212,7 +212,7 @@ class TestValidateResult: parameters=[ ParameterConfig( name="status", - type="select", # type: ignore + type="select", description="Status", required=True, options=["active", "inactive"], @@ -400,7 +400,7 @@ class TestTransformResult: parameters=[ ParameterConfig( name="status", - type="select", # type: ignore + type="select", description="Status", required=True, options=["active", "inactive"], @@ -414,7 +414,7 @@ class TestTransformResult: parameters=[ ParameterConfig( name="status", - type="select", # type: ignore + type="select", description="Status", required=True, options=["active", "inactive"], diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 3ae5edb383..f76e81ae55 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -248,4 +248,4 @@ def test_constructor_with_extra_key(): # Test that SystemVariable should forbid extra keys with pytest.raises(ValidationError): # This should fail because there is an unexpected key. - SystemVariable(invalid_key=1) # type: ignore + SystemVariable(invalid_key=1) diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py index c4c376a070..9aa157a651 100644 --- a/api/tests/unit_tests/libs/test_external_api.py +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -14,36 +14,36 @@ def _create_api_app(): api = ExternalApi(bp) @api.route("/bad-request") - class Bad(Resource): # type: ignore - def get(self): # type: ignore + class Bad(Resource): + def get(self): raise BadRequest("invalid input") @api.route("/unauth") - class Unauth(Resource): # type: ignore - def get(self): # type: ignore + class Unauth(Resource): + def get(self): raise Unauthorized("auth required") @api.route("/value-error") - class ValErr(Resource): # type: ignore - def get(self): # type: ignore + class ValErr(Resource): + def get(self): raise ValueError("boom") @api.route("/quota") - class Quota(Resource): # type: ignore - def get(self): # type: ignore + class Quota(Resource): + def get(self): raise AppInvokeQuotaExceededError("quota exceeded") @api.route("/general") - class Gen(Resource): # type: ignore - def get(self): # type: ignore + class Gen(Resource): + def get(self): raise RuntimeError("oops") # Note: We avoid altering default_mediatype to keep normal error paths # Special 400 message rewrite @api.route("/json-empty") - class JsonEmpty(Resource): # type: ignore - def get(self): # type: ignore + class JsonEmpty(Resource): + def get(self): e = BadRequest() # Force the specific message the handler rewrites e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)" @@ -51,11 +51,11 @@ def _create_api_app(): # 400 mapping payload path @api.route("/param-errors") - class ParamErrors(Resource): # type: ignore - def get(self): # type: ignore + class ParamErrors(Resource): + def get(self): e = BadRequest() # Coerce a mapping description to trigger param error shaping - e.description = {"field": "is required"} # type: ignore[assignment] + e.description = {"field": "is required"} raise e app.register_blueprint(bp, url_prefix="/api") @@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none(): orig_exc_info = ext.sys.exc_info try: - ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment] + ext.sys.exc_info = lambda: (None, None, None) app = _create_api_app() client = app.test_client() diff --git a/api/tests/unit_tests/libs/test_flask_utils.py b/api/tests/unit_tests/libs/test_flask_utils.py index e30433bfce..9cab0db24c 100644 --- a/api/tests/unit_tests/libs/test_flask_utils.py +++ b/api/tests/unit_tests/libs/test_flask_utils.py @@ -67,7 +67,7 @@ def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: # without preserve_flask_contexts result["user_accessible"] = current_user.is_authenticated except Exception as e: - result["error"] = str(e) # type: ignore + result["error"] = str(e) # Run the function in a separate thread thread = threading.Thread(target=check_user_in_thread) @@ -110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, else: result["user_accessible"] = False except Exception as e: - result["error"] = str(e) # type: ignore + result["error"] = str(e) # Run the function in a separate thread thread = threading.Thread(target=check_user_in_thread_with_manager) diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py index 3e0c235fff..7b7f086dac 100644 --- a/api/tests/unit_tests/libs/test_oauth_base.py +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented(): oauth.get_raw_user_info("token") with pytest.raises(NotImplementedError): - oauth._transform_user_info({}) # type: ignore[name-defined] + oauth._transform_user_info({}) diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py index c77c5b08f3..5189b68e87 100644 --- a/api/tests/unit_tests/oss/__mock/tencent_cos.py +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from qcloud_cos import CosS3Client # type: ignore -from qcloud_cos.streambody import StreamBody # type: ignore +from qcloud_cos import CosS3Client +from qcloud_cos.streambody import StreamBody from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py index 88df59f91c..649d93a202 100644 --- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from tos import TosClientV2 # type: ignore -from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore +from tos import TosClientV2 +from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index d289751800..303f0493bd 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from qcloud_cos import CosConfig # type: ignore +from qcloud_cos import CosConfig from extensions.storage.tencent_cos_storage import TencentCosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 1659205ec0..a06623a69e 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from tos import TosClientV2 # type: ignore +from tos import TosClientV2 from extensions.storage.volcengine_tos_storage import VolcengineTosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py index d23298f096..c6c3f677fb 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -125,13 +125,13 @@ class TestApiKeyAuthService: mock_session.commit = Mock() args_copy = self.mock_args.copy() - original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore + original_key = args_copy["credentials"]["config"]["api_key"] ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy) # Verify original key is replaced with encrypted key - assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore - assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore + assert args_copy["credentials"]["config"]["api_key"] == encrypted_key + assert args_copy["credentials"]["config"]["api_key"] != original_key # Verify encryption function is called correctly mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key) @@ -268,7 +268,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_empty_credentials(self): """Test API key auth args validation - empty credentials""" args = self.mock_args.copy() - args["credentials"] = None # type: ignore + args["credentials"] = None with pytest.raises(ValueError, match="credentials is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -284,7 +284,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_missing_auth_type(self): """Test API key auth args validation - missing auth_type""" args = self.mock_args.copy() - del args["credentials"]["auth_type"] # type: ignore + del args["credentials"]["auth_type"] with pytest.raises(ValueError, match="auth_type is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -292,7 +292,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_empty_auth_type(self): """Test API key auth args validation - empty auth_type""" args = self.mock_args.copy() - args["credentials"]["auth_type"] = "" # type: ignore + args["credentials"]["auth_type"] = "" with pytest.raises(ValueError, match="auth_type is required"): ApiKeyAuthService.validate_api_key_auth_args(args) @@ -380,7 +380,7 @@ class TestApiKeyAuthService: def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self): """Test API key auth args validation - dict credentials with list auth_type""" args = self.mock_args.copy() - args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string + args["credentials"]["auth_type"] = ["api_key"] # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy # So this should not raise exception, this test should pass diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py index 30990f8d50..e2607f0fb1 100644 --- a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py +++ b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py @@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter: encrypter = SystemOAuthEncrypter("test_secret") with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params(None) # type: ignore + encrypter.encrypt_oauth_params(None) with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params("not_a_dict") # type: ignore + encrypter.encrypt_oauth_params("not_a_dict") def test_decrypt_oauth_params_basic(self): """Test basic OAuth parameters decryption""" @@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter: encrypter = SystemOAuthEncrypter("test_secret") with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) # type: ignore + encrypter.decrypt_oauth_params(123) assert "encrypted_data must be a string" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(None) # type: ignore + encrypter.decrypt_oauth_params(None) assert "encrypted_data must be a string" in str(exc_info.value) @@ -461,14 +461,14 @@ class TestConvenienceFunctions: """Test convenience functions with error conditions""" # Test encryption with invalid input with pytest.raises(Exception): # noqa: B017 - encrypt_system_oauth_params(None) # type: ignore + encrypt_system_oauth_params(None) # Test decryption with invalid input with pytest.raises(ValueError): decrypt_system_oauth_params("") with pytest.raises(ValueError): - decrypt_system_oauth_params(None) # type: ignore + decrypt_system_oauth_params(None) class TestErrorHandling: @@ -501,7 +501,7 @@ class TestErrorHandling: # Test non-string error with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) # type: ignore + encrypter.decrypt_oauth_params(123) assert "encrypted_data must be a string" in str(exc_info.value) # Test invalid format error diff --git a/web/.storybook/__mocks__/context-block.tsx b/web/.storybook/__mocks__/context-block.tsx new file mode 100644 index 0000000000..8a9d8625cc --- /dev/null +++ b/web/.storybook/__mocks__/context-block.tsx @@ -0,0 +1,4 @@ +// Mock for context-block plugin to avoid circular dependency in Storybook +export const ContextBlockNode = null +export const ContextBlockReplacementBlock = null +export default null diff --git a/web/.storybook/__mocks__/history-block.tsx b/web/.storybook/__mocks__/history-block.tsx new file mode 100644 index 0000000000..e3c3965d13 --- /dev/null +++ b/web/.storybook/__mocks__/history-block.tsx @@ -0,0 +1,4 @@ +// Mock for history-block plugin to avoid circular dependency in Storybook +export const HistoryBlockNode = null +export const HistoryBlockReplacementBlock = null +export default null diff --git a/web/.storybook/__mocks__/query-block.tsx b/web/.storybook/__mocks__/query-block.tsx new file mode 100644 index 0000000000..d82f51363a --- /dev/null +++ b/web/.storybook/__mocks__/query-block.tsx @@ -0,0 +1,4 @@ +// Mock for query-block plugin to avoid circular dependency in Storybook +export const QueryBlockNode = null +export const QueryBlockReplacementBlock = null +export default null diff --git a/web/.storybook/main.ts b/web/.storybook/main.ts index 0605c71346..e656115ceb 100644 --- a/web/.storybook/main.ts +++ b/web/.storybook/main.ts @@ -1,4 +1,9 @@ import type { StorybookConfig } from '@storybook/nextjs' +import path from 'node:path' +import { fileURLToPath } from 'node:url' + +const __filename = fileURLToPath(import.meta.url) +const __dirname = path.dirname(__filename) const config: StorybookConfig = { stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'], @@ -25,5 +30,17 @@ const config: StorybookConfig = { docs: { defaultName: 'Documentation', }, + webpackFinal: async (config) => { + // Add alias to mock problematic modules with circular dependencies + config.resolve = config.resolve || {} + config.resolve.alias = { + ...config.resolve.alias, + // Mock the plugin index files to avoid circular dependencies + [path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/context-block/index.tsx')]: path.resolve(__dirname, '__mocks__/context-block.tsx'), + [path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/history-block/index.tsx')]: path.resolve(__dirname, '__mocks__/history-block.tsx'), + [path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/query-block/index.tsx')]: path.resolve(__dirname, '__mocks__/query-block.tsx'), + } + return config + }, } export default config diff --git a/web/__tests__/navigation-utils.test.ts b/web/__tests__/navigation-utils.test.ts index 9a388505d6..fa4986e63d 100644 --- a/web/__tests__/navigation-utils.test.ts +++ b/web/__tests__/navigation-utils.test.ts @@ -160,8 +160,7 @@ describe('Navigation Utilities', () => { page: 1, limit: '', keyword: 'test', - empty: null, - undefined, + filter: '', }) expect(path).toBe('/datasets/123/documents?page=1&keyword=test') diff --git a/web/__tests__/real-browser-flicker.test.tsx b/web/__tests__/real-browser-flicker.test.tsx index f71e8de515..0a0ea0c062 100644 --- a/web/__tests__/real-browser-flicker.test.tsx +++ b/web/__tests__/real-browser-flicker.test.tsx @@ -39,28 +39,38 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa const isDarkQuery = DARK_MODE_MEDIA_QUERY.test(query) const matches = isDarkQuery ? systemPrefersDark : false + const handleAddListener = (listener: (event: MediaQueryListEvent) => void) => { + listeners.add(listener) + } + + const handleRemoveListener = (listener: (event: MediaQueryListEvent) => void) => { + listeners.delete(listener) + } + + const handleAddEventListener = (_event: string, listener: EventListener) => { + if (typeof listener === 'function') + listeners.add(listener as (event: MediaQueryListEvent) => void) + } + + const handleRemoveEventListener = (_event: string, listener: EventListener) => { + if (typeof listener === 'function') + listeners.delete(listener as (event: MediaQueryListEvent) => void) + } + + const handleDispatchEvent = (event: Event) => { + listeners.forEach(listener => listener(event as MediaQueryListEvent)) + return true + } + const mediaQueryList: MediaQueryList = { matches, media: query, onchange: null, - addListener: (listener: MediaQueryListListener) => { - listeners.add(listener) - }, - removeListener: (listener: MediaQueryListListener) => { - listeners.delete(listener) - }, - addEventListener: (_event, listener: EventListener) => { - if (typeof listener === 'function') - listeners.add(listener as MediaQueryListListener) - }, - removeEventListener: (_event, listener: EventListener) => { - if (typeof listener === 'function') - listeners.delete(listener as MediaQueryListListener) - }, - dispatchEvent: (event: Event) => { - listeners.forEach(listener => listener(event as MediaQueryListEvent)) - return true - }, + addListener: handleAddListener, + removeListener: handleRemoveListener, + addEventListener: handleAddEventListener, + removeEventListener: handleRemoveEventListener, + dispatchEvent: handleDispatchEvent, } return mediaQueryList @@ -69,6 +79,121 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia) } +// Helper function to create timing page component +const createTimingPageComponent = ( + timingData: Array<{ phase: string; timestamp: number; styles: { backgroundColor: string; color: string } }>, +) => { + const recordTiming = (phase: string, styles: { backgroundColor: string; color: string }) => { + timingData.push({ + phase, + timestamp: performance.now(), + styles, + }) + } + + const TimingPageComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + const isDark = mounted ? theme === 'dark' : false + + const currentStyles = { + backgroundColor: isDark ? '#1f2937' : '#ffffff', + color: isDark ? '#ffffff' : '#000000', + } + + recordTiming(mounted ? 'CSR' : 'Initial', currentStyles) + + useEffect(() => { + setMounted(true) + }, []) + + return ( +
+
+ Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'} +
+
+ ) + } + + return TimingPageComponent +} + +// Helper function to create CSS test component +const createCSSTestComponent = ( + cssStates: Array<{ className: string; timestamp: number }>, +) => { + const recordCSSState = (className: string) => { + cssStates.push({ + className, + timestamp: performance.now(), + }) + } + + const CSSTestComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + const isDark = mounted ? theme === 'dark' : false + + const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}` + + recordCSSState(className) + + useEffect(() => { + setMounted(true) + }, []) + + return ( +
+
Classes: {className}
+
+ ) + } + + return CSSTestComponent +} + +// Helper function to create performance test component +const createPerformanceTestComponent = ( + performanceMarks: Array<{ event: string; timestamp: number }>, +) => { + const recordPerformanceMark = (event: string) => { + performanceMarks.push({ event, timestamp: performance.now() }) + } + + const PerformanceTestComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + + recordPerformanceMark('component-render') + + useEffect(() => { + recordPerformanceMark('mount-start') + setMounted(true) + recordPerformanceMark('mount-complete') + }, []) + + useEffect(() => { + if (theme) + recordPerformanceMark('theme-available') + }, [theme]) + + return ( +
+ Mounted: {mounted.toString()} | Theme: {theme || 'loading'} +
+ ) + } + + return PerformanceTestComponent +} + // Simulate real page component based on Dify's actual theme usage const PageComponent = () => { const [mounted, setMounted] = useState(false) @@ -227,39 +352,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { setupMockEnvironment('dark') const timingData: Array<{ phase: string; timestamp: number; styles: any }> = [] - - const TimingPageComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - const isDark = mounted ? theme === 'dark' : false - - // Record timing and styles for each render phase - const currentStyles = { - backgroundColor: isDark ? '#1f2937' : '#ffffff', - color: isDark ? '#ffffff' : '#000000', - } - - timingData.push({ - phase: mounted ? 'CSR' : 'Initial', - timestamp: performance.now(), - styles: currentStyles, - }) - - useEffect(() => { - setMounted(true) - }, []) - - return ( -
-
- Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'} -
-
- ) - } + const TimingPageComponent = createTimingPageComponent(timingData) render( @@ -295,33 +388,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { setupMockEnvironment('dark') const cssStates: Array<{ className: string; timestamp: number }> = [] - - const CSSTestComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - const isDark = mounted ? theme === 'dark' : false - - // Simulate Tailwind CSS class application - const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}` - - cssStates.push({ - className, - timestamp: performance.now(), - }) - - useEffect(() => { - setMounted(true) - }, []) - - return ( -
-
Classes: {className}
-
- ) - } + const CSSTestComponent = createCSSTestComponent(cssStates) render( @@ -413,34 +480,12 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { test('verifies ThemeProvider position fix reduces initialization delay', async () => { const performanceMarks: Array<{ event: string; timestamp: number }> = [] - const PerformanceTestComponent = () => { - const [mounted, setMounted] = useState(false) - const { theme } = useTheme() - - performanceMarks.push({ event: 'component-render', timestamp: performance.now() }) - - useEffect(() => { - performanceMarks.push({ event: 'mount-start', timestamp: performance.now() }) - setMounted(true) - performanceMarks.push({ event: 'mount-complete', timestamp: performance.now() }) - }, []) - - useEffect(() => { - if (theme) - performanceMarks.push({ event: 'theme-available', timestamp: performance.now() }) - }, [theme]) - - return ( -
- Mounted: {mounted.toString()} | Theme: {theme || 'loading'} -
- ) - } - setupMockEnvironment('dark') expect(window.localStorage.getItem('theme')).toBe('dark') + const PerformanceTestComponent = createPerformanceTestComponent(performanceMarks) + render( diff --git a/web/__tests__/unified-tags-logic.test.ts b/web/__tests__/unified-tags-logic.test.ts index c920e28e0a..ec73a6a268 100644 --- a/web/__tests__/unified-tags-logic.test.ts +++ b/web/__tests__/unified-tags-logic.test.ts @@ -70,14 +70,18 @@ describe('Unified Tags Editing - Pure Logic Tests', () => { }) describe('Fallback Logic (from layout-main.tsx)', () => { + type Tag = { id: string; name: string } + type AppDetail = { tags: Tag[] } + type FallbackResult = { tags?: Tag[] } | null + // no-op it('should trigger fallback when tags are missing or empty', () => { - const appDetailWithoutTags = { tags: [] } - const appDetailWithTags = { tags: [{ id: 'tag1' }] } - const appDetailWithUndefinedTags = { tags: undefined as any } + const appDetailWithoutTags: AppDetail = { tags: [] } + const appDetailWithTags: AppDetail = { tags: [{ id: 'tag1', name: 't' }] } + const appDetailWithUndefinedTags: { tags: Tag[] | undefined } = { tags: undefined } // This simulates the condition in layout-main.tsx - const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0 - const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0 + const shouldFallback1 = appDetailWithoutTags.tags.length === 0 + const shouldFallback2 = appDetailWithTags.tags.length === 0 const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0 expect(shouldFallback1).toBe(true) // Empty array should trigger fallback @@ -86,24 +90,26 @@ describe('Unified Tags Editing - Pure Logic Tests', () => { }) it('should preserve tags when fallback succeeds', () => { - const originalAppDetail = { tags: [] as any[] } - const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] } + const originalAppDetail: AppDetail = { tags: [] } + const fallbackResult: { tags?: Tag[] } = { tags: [{ id: 'tag1', name: 'fallback-tag' }] } // This simulates the successful fallback in layout-main.tsx - if (fallbackResult?.tags) - originalAppDetail.tags = fallbackResult.tags + const tags = fallbackResult.tags + if (tags) + originalAppDetail.tags = tags expect(originalAppDetail.tags).toEqual(fallbackResult.tags) expect(originalAppDetail.tags.length).toBe(1) }) it('should continue with empty tags when fallback fails', () => { - const originalAppDetail: { tags: any[] } = { tags: [] } - const fallbackResult: { tags?: any[] } | null = null + const originalAppDetail: AppDetail = { tags: [] } + const fallbackResult = null as FallbackResult // This simulates fallback failure in layout-main.tsx - if (fallbackResult?.tags) - originalAppDetail.tags = fallbackResult.tags + const tags: Tag[] | undefined = fallbackResult && 'tags' in fallbackResult ? fallbackResult.tags : undefined + if (tags) + originalAppDetail.tags = tags expect(originalAppDetail.tags).toEqual([]) }) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index e4c3f60c12..0ad02ad7f3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -73,7 +73,7 @@ const ConfigPopup: FC = ({ } }, [onChooseProvider]) - const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => { + const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => { onConfigUpdated(currentProvider!, payload) hideConfigModal() }, [currentProvider, hideConfigModal, onConfigUpdated]) diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index c26ea7e045..16d291d4b4 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -6,7 +6,6 @@ import { useWebAppStore } from '@/context/web-app-context' import { useRouter, useSearchParams } from 'next/navigation' import AppUnavailable from '@/app/components/base/app-unavailable' import { useTranslation } from 'react-i18next' -import { AccessMode } from '@/models/access-control' import { webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' import { fetchAccessToken } from '@/service/share' import Loading from '@/app/components/base/loading' @@ -35,7 +34,6 @@ const Splash: FC = ({ children }) => { router.replace(url) }, [getSigninUrl, router, webAppLogout, shareCode]) - const needCheckIsLogin = webAppAccessMode !== AccessMode.PUBLIC const [isLoading, setIsLoading] = useState(true) useEffect(() => { if (message) { @@ -58,8 +56,8 @@ const Splash: FC = ({ children }) => { } (async () => { - const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(needCheckIsLogin, shareCode!) - + // if access mode is public, user login is always true, but the app login(passport) may be expired + const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(shareCode!) if (userLoggedIn && appLoggedIn) { redirectOrFinish() } @@ -87,7 +85,6 @@ const Splash: FC = ({ children }) => { router, message, webAppAccessMode, - needCheckIsLogin, tokenFromUrl]) if (message) { diff --git a/web/app/components/base/action-button/index.stories.tsx b/web/app/components/base/action-button/index.stories.tsx new file mode 100644 index 0000000000..c174adbc73 --- /dev/null +++ b/web/app/components/base/action-button/index.stories.tsx @@ -0,0 +1,262 @@ +import type { Meta, StoryObj } from '@storybook/nextjs' +import { RiAddLine, RiDeleteBinLine, RiEditLine, RiMore2Fill, RiSaveLine, RiShareLine } from '@remixicon/react' +import ActionButton, { ActionButtonState } from '.' + +const meta = { + title: 'Base/ActionButton', + component: ActionButton, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Action button component with multiple sizes and states. Commonly used for toolbar actions and inline operations.', + }, + }, + }, + tags: ['autodocs'], + argTypes: { + size: { + control: 'select', + options: ['xs', 'm', 'l', 'xl'], + description: 'Button size', + }, + state: { + control: 'select', + options: [ + ActionButtonState.Default, + ActionButtonState.Active, + ActionButtonState.Disabled, + ActionButtonState.Destructive, + ActionButtonState.Hover, + ], + description: 'Button state', + }, + children: { + control: 'text', + description: 'Button content', + }, + disabled: { + control: 'boolean', + description: 'Native disabled state', + }, + }, +} satisfies Meta + +export default meta +type Story = StoryObj + +// Default state +export const Default: Story = { + args: { + size: 'm', + children: , + }, +} + +// With text +export const WithText: Story = { + args: { + size: 'm', + children: 'Edit', + }, +} + +// Icon with text +export const IconWithText: Story = { + args: { + size: 'm', + children: ( + <> + + Add Item + + ), + }, +} + +// Size variations +export const ExtraSmall: Story = { + args: { + size: 'xs', + children: , + }, +} + +export const Small: Story = { + args: { + size: 'xs', + children: , + }, +} + +export const Medium: Story = { + args: { + size: 'm', + children: , + }, +} + +export const Large: Story = { + args: { + size: 'l', + children: , + }, +} + +export const ExtraLarge: Story = { + args: { + size: 'xl', + children: , + }, +} + +// State variations +export const ActiveState: Story = { + args: { + size: 'm', + state: ActionButtonState.Active, + children: , + }, +} + +export const DisabledState: Story = { + args: { + size: 'm', + state: ActionButtonState.Disabled, + children: , + }, +} + +export const DestructiveState: Story = { + args: { + size: 'm', + state: ActionButtonState.Destructive, + children: , + }, +} + +export const HoverState: Story = { + args: { + size: 'm', + state: ActionButtonState.Hover, + children: , + }, +} + +// Real-world examples +export const ToolbarActions: Story = { + render: () => ( +
+ + + + + + + + + +
+ + + +
+ ), +} + +export const InlineActions: Story = { + render: () => ( +
+ Item name + + + + + + +
+ ), +} + +export const SizeComparison: Story = { + render: () => ( +
+
+ + + + XS +
+
+ + + + S +
+
+ + + + M +
+
+ + + + L +
+
+ + + + XL +
+
+ ), +} + +export const StateComparison: Story = { + render: () => ( +
+
+ + + + Default +
+
+ + + + Active +
+
+ + + + Hover +
+
+ + + + Disabled +
+
+ + + + Destructive +
+
+ ), +} + +// Interactive playground +export const Playground: Story = { + args: { + size: 'm', + state: ActionButtonState.Default, + children: , + }, +} diff --git a/web/app/components/base/auto-height-textarea/index.stories.tsx b/web/app/components/base/auto-height-textarea/index.stories.tsx new file mode 100644 index 0000000000..f083e4f56d --- /dev/null +++ b/web/app/components/base/auto-height-textarea/index.stories.tsx @@ -0,0 +1,204 @@ +import type { Meta, StoryObj } from '@storybook/nextjs' +import { useState } from 'react' +import AutoHeightTextarea from '.' + +const meta = { + title: 'Base/AutoHeightTextarea', + component: AutoHeightTextarea, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Auto-resizing textarea component that expands and contracts based on content, with configurable min/max height constraints.', + }, + }, + }, + tags: ['autodocs'], + argTypes: { + placeholder: { + control: 'text', + description: 'Placeholder text', + }, + value: { + control: 'text', + description: 'Textarea value', + }, + minHeight: { + control: 'number', + description: 'Minimum height in pixels', + }, + maxHeight: { + control: 'number', + description: 'Maximum height in pixels', + }, + autoFocus: { + control: 'boolean', + description: 'Auto focus on mount', + }, + className: { + control: 'text', + description: 'Additional CSS classes', + }, + wrapperClassName: { + control: 'text', + description: 'Wrapper CSS classes', + }, + }, +} satisfies Meta + +export default meta +type Story = StoryObj + +// Interactive demo wrapper +const AutoHeightTextareaDemo = (args: any) => { + const [value, setValue] = useState(args.value || '') + + return ( +
+ { + setValue(e.target.value) + console.log('Text changed:', e.target.value) + }} + /> +
+ ) +} + +// Default state +export const Default: Story = { + render: args => , + args: { + placeholder: 'Type something...', + value: '', + minHeight: 36, + maxHeight: 96, + className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500', + }, +} + +// With initial value +export const WithInitialValue: Story = { + render: args => , + args: { + placeholder: 'Type something...', + value: 'This is a pre-filled textarea with some initial content.', + minHeight: 36, + maxHeight: 96, + className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500', + }, +} + +// With multiline content +export const MultilineContent: Story = { + render: args => , + args: { + placeholder: 'Type something...', + value: 'Line 1\nLine 2\nLine 3\nLine 4\nThis textarea automatically expands to fit the content.', + minHeight: 36, + maxHeight: 96, + className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500', + }, +} + +// Custom min height +export const CustomMinHeight: Story = { + render: args => , + args: { + placeholder: 'Taller minimum height...', + value: '', + minHeight: 100, + maxHeight: 200, + className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500', + }, +} + +// Small max height (scrollable) +export const SmallMaxHeight: Story = { + render: args => , + args: { + placeholder: 'Type multiple lines...', + value: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6\nThis will become scrollable when it exceeds max height.', + minHeight: 36, + maxHeight: 80, + className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500', + }, +} + +// Auto focus enabled +export const AutoFocus: Story = { + render: args => , + args: { + placeholder: 'This textarea auto-focuses on mount', + value: '', + minHeight: 36, + maxHeight: 96, + autoFocus: true, + className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500', + }, +} + +// With custom styling +export const CustomStyling: Story = { + render: args => , + args: { + placeholder: 'Custom styled textarea...', + value: '', + minHeight: 50, + maxHeight: 150, + className: 'w-full p-3 bg-gray-50 border-2 border-blue-400 rounded-xl text-lg focus:outline-none focus:bg-white focus:border-blue-600', + wrapperClassName: 'shadow-lg', + }, +} + +// Long content example +export const LongContent: Story = { + render: args => , + args: { + placeholder: 'Type something...', + value: 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.\n\nUt enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.\n\nDuis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.\n\nExcepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.', + minHeight: 36, + maxHeight: 200, + className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500', + }, +} + +// Real-world example - Chat input +export const ChatInput: Story = { + render: args => , + args: { + placeholder: 'Type your message...', + value: '', + minHeight: 40, + maxHeight: 120, + className: 'w-full px-4 py-2 bg-gray-100 border border-gray-300 rounded-2xl text-sm focus:outline-none focus:bg-white focus:ring-2 focus:ring-blue-500', + }, +} + +// Real-world example - Comment box +export const CommentBox: Story = { + render: args => , + args: { + placeholder: 'Write a comment...', + value: '', + minHeight: 60, + maxHeight: 200, + className: 'w-full p-3 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-indigo-500', + }, +} + +// Interactive playground +export const Playground: Story = { + render: args => , + args: { + placeholder: 'Type something...', + value: '', + minHeight: 36, + maxHeight: 96, + autoFocus: false, + className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500', + wrapperClassName: '', + }, +} diff --git a/web/app/components/base/auto-height-textarea/index.tsx b/web/app/components/base/auto-height-textarea/index.tsx index da412a176d..fb64bf9db4 100644 --- a/web/app/components/base/auto-height-textarea/index.tsx +++ b/web/app/components/base/auto-height-textarea/index.tsx @@ -31,7 +31,7 @@ const AutoHeightTextarea = ( onKeyDown, onKeyUp, }: IProps & { - ref: React.RefObject; + ref?: React.RefObject; }, ) => { // eslint-disable-next-line react-hooks/rules-of-hooks diff --git a/web/app/components/base/block-input/index.stories.tsx b/web/app/components/base/block-input/index.stories.tsx new file mode 100644 index 0000000000..0685f4150f --- /dev/null +++ b/web/app/components/base/block-input/index.stories.tsx @@ -0,0 +1,191 @@ +import type { Meta, StoryObj } from '@storybook/nextjs' +import { useState } from 'react' +import BlockInput from '.' + +const meta = { + title: 'Base/BlockInput', + component: BlockInput, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Block input component with variable highlighting. Supports {{variable}} syntax with validation and visual highlighting of variable names.', + }, + }, + }, + tags: ['autodocs'], + argTypes: { + value: { + control: 'text', + description: 'Input value (supports {{variable}} syntax)', + }, + className: { + control: 'text', + description: 'Wrapper CSS classes', + }, + highLightClassName: { + control: 'text', + description: 'CSS class for highlighted variables (default: text-blue-500)', + }, + readonly: { + control: 'boolean', + description: 'Read-only mode', + }, + }, +} satisfies Meta + +export default meta +type Story = StoryObj + +// Interactive demo wrapper +const BlockInputDemo = (args: any) => { + const [value, setValue] = useState(args.value || '') + const [keys, setKeys] = useState([]) + + return ( +
+ { + setValue(newValue) + setKeys(extractedKeys) + console.log('Value confirmed:', newValue) + console.log('Extracted keys:', extractedKeys) + }} + /> + {keys.length > 0 && ( +
+
Detected Variables:
+
+ {keys.map(key => ( + + {key} + + ))} +
+
+ )} +
+ ) +} + +// Default state +export const Default: Story = { + render: args => , + args: { + value: '', + readonly: false, + }, +} + +// With single variable +export const SingleVariable: Story = { + render: args => , + args: { + value: 'Hello {{name}}, welcome to the application!', + readonly: false, + }, +} + +// With multiple variables +export const MultipleVariables: Story = { + render: args => , + args: { + value: 'Dear {{user_name}},\n\nYour order {{order_id}} has been shipped to {{address}}.\n\nThank you for shopping with us!', + readonly: false, + }, +} + +// Complex template +export const ComplexTemplate: Story = { + render: args => , + args: { + value: 'Hi {{customer_name}},\n\nYour {{product_type}} subscription will renew on {{renewal_date}} for {{amount}}.\n\nYour payment method ending in {{card_last_4}} will be charged.\n\nQuestions? Contact us at {{support_email}}.', + readonly: false, + }, +} + +// Read-only mode +export const ReadOnlyMode: Story = { + render: args => , + args: { + value: 'This is a read-only template with {{variable1}} and {{variable2}}.\n\nYou cannot edit this content.', + readonly: true, + }, +} + +// Empty state +export const EmptyState: Story = { + render: args => , + args: { + value: '', + readonly: false, + }, +} + +// Long content +export const LongContent: Story = { + render: args => , + args: { + value: 'Dear {{recipient_name}},\n\nWe are writing to inform you about the upcoming changes to your {{service_name}} account.\n\nEffective {{effective_date}}, your plan will include:\n\n1. Access to {{feature_1}}\n2. {{feature_2}} with unlimited usage\n3. Priority support via {{support_channel}}\n4. Monthly reports sent to {{email_address}}\n\nYour new monthly rate will be {{new_price}}, compared to your current rate of {{old_price}}.\n\nIf you have any questions, please contact our team at {{contact_info}}.\n\nBest regards,\n{{company_name}} Team', + readonly: false, + }, +} + +// Variables with underscores +export const VariablesWithUnderscores: Story = { + render: args => , + args: { + value: 'User {{user_id}} from {{user_country}} has {{total_orders}} orders with status {{order_status}}.', + readonly: false, + }, +} + +// Adjacent variables +export const AdjacentVariables: Story = { + render: args => , + args: { + value: 'File: {{file_name}}.{{file_extension}} ({{file_size}}{{size_unit}})', + readonly: false, + }, +} + +// Real-world example - Email template +export const EmailTemplate: Story = { + render: args => , + args: { + value: 'Subject: Your {{service_name}} account has been created\n\nHi {{first_name}},\n\nWelcome to {{company_name}}! Your account is now active.\n\nUsername: {{username}}\nEmail: {{email}}\n\nGet started at {{app_url}}\n\nThanks,\nThe {{company_name}} Team', + readonly: false, + }, +} + +// Real-world example - Notification template +export const NotificationTemplate: Story = { + render: args => , + args: { + value: '🔔 {{user_name}} mentioned you in {{channel_name}}\n\n"{{message_preview}}"\n\nReply now: {{message_url}}', + readonly: false, + }, +} + +// Custom styling +export const CustomStyling: Story = { + render: args => , + args: { + value: 'This template uses {{custom_variable}} with custom styling.', + readonly: false, + className: 'bg-gray-50 border-2 border-blue-200', + }, +} + +// Interactive playground +export const Playground: Story = { + render: args => , + args: { + value: 'Try editing this text and adding variables like {{example}}', + readonly: false, + className: '', + highLightClassName: '', + }, +} diff --git a/web/app/components/base/chat/chat/answer/index.stories.tsx b/web/app/components/base/chat/chat/answer/index.stories.tsx index 1f45844ec4..a83c0fea61 100644 --- a/web/app/components/base/chat/chat/answer/index.stories.tsx +++ b/web/app/components/base/chat/chat/answer/index.stories.tsx @@ -1,7 +1,5 @@ import type { Meta, StoryObj } from '@storybook/nextjs' - import type { ChatItem } from '../../types' -import { mockedWorkflowProcess } from './__mocks__/workflowProcess' import { markdownContent } from './__mocks__/markdownContent' import { markdownContentSVG } from './__mocks__/markdownContentSVG' import Answer from '.' @@ -34,6 +32,11 @@ const mockedBaseChatItem = { content: 'Hello, how can I assist you today?', } satisfies ChatItem +const mockedWorkflowProcess = { + status: 'succeeded', + tracing: [], +} + export const Basic: Story = { args: { item: mockedBaseChatItem, diff --git a/web/app/components/base/chat/chat/type.ts b/web/app/components/base/chat/chat/type.ts index a9e2ada262..d4cf460884 100644 --- a/web/app/components/base/chat/chat/type.ts +++ b/web/app/components/base/chat/chat/type.ts @@ -17,11 +17,11 @@ export type FeedbackType = { export type FeedbackFunc = ( messageId: string, - feedback: FeedbackType + feedback: FeedbackType, ) => Promise export type SubmitAnnotationFunc = ( messageId: string, - content: string + content: string, ) => Promise export type DisplayScene = 'web' | 'console' diff --git a/web/app/components/base/checkbox/index.stories.tsx b/web/app/components/base/checkbox/index.stories.tsx new file mode 100644 index 0000000000..65fa8e1b97 --- /dev/null +++ b/web/app/components/base/checkbox/index.stories.tsx @@ -0,0 +1,394 @@ +import type { Meta, StoryObj } from '@storybook/nextjs' +import { useState } from 'react' +import Checkbox from '.' + +// Helper function for toggling items in an array +const createToggleItem = ( + items: T[], + setItems: (items: T[]) => void, +) => (id: string) => { + setItems(items.map(item => + item.id === id ? { ...item, checked: !item.checked } as T : item, + )) +} + +const meta = { + title: 'Base/Checkbox', + component: Checkbox, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Checkbox component with support for checked, unchecked, indeterminate, and disabled states.', + }, + }, + }, + tags: ['autodocs'], + argTypes: { + checked: { + control: 'boolean', + description: 'Checked state', + }, + indeterminate: { + control: 'boolean', + description: 'Indeterminate state (partially checked)', + }, + disabled: { + control: 'boolean', + description: 'Disabled state', + }, + className: { + control: 'text', + description: 'Additional CSS classes', + }, + id: { + control: 'text', + description: 'HTML id attribute', + }, + }, +} satisfies Meta + +export default meta +type Story = StoryObj + +// Interactive demo wrapper +const CheckboxDemo = (args: any) => { + const [checked, setChecked] = useState(args.checked || false) + + return ( +
+ { + if (!args.disabled) { + setChecked(!checked) + console.log('Checkbox toggled:', !checked) + } + }} + /> + + {checked ? 'Checked' : 'Unchecked'} + +
+ ) +} + +// Default unchecked +export const Default: Story = { + render: args => , + args: { + checked: false, + disabled: false, + indeterminate: false, + }, +} + +// Checked state +export const Checked: Story = { + render: args => , + args: { + checked: true, + disabled: false, + indeterminate: false, + }, +} + +// Indeterminate state +export const Indeterminate: Story = { + render: args => , + args: { + checked: false, + disabled: false, + indeterminate: true, + }, +} + +// Disabled unchecked +export const DisabledUnchecked: Story = { + render: args => , + args: { + checked: false, + disabled: true, + indeterminate: false, + }, +} + +// Disabled checked +export const DisabledChecked: Story = { + render: args => , + args: { + checked: true, + disabled: true, + indeterminate: false, + }, +} + +// Disabled indeterminate +export const DisabledIndeterminate: Story = { + render: args => , + args: { + checked: false, + disabled: true, + indeterminate: true, + }, +} + +// State comparison +export const StateComparison: Story = { + render: () => ( +
+
+
+ undefined} /> + Unchecked +
+
+ undefined} /> + Checked +
+
+ undefined} /> + Indeterminate +
+
+
+
+ undefined} /> + Disabled +
+
+ undefined} /> + Disabled Checked +
+
+ undefined} /> + Disabled Indeterminate +
+
+
+ ), +} + +// With labels +const WithLabelsDemo = () => { + const [items, setItems] = useState([ + { id: '1', label: 'Enable notifications', checked: true }, + { id: '2', label: 'Enable email updates', checked: false }, + { id: '3', label: 'Enable SMS alerts', checked: false }, + ]) + + const toggleItem = createToggleItem(items, setItems) + + return ( +
+ {items.map(item => ( +
+ toggleItem(item.id)} + /> + +
+ ))} +
+ ) +} + +export const WithLabels: Story = { + render: () => , +} + +// Select all example +const SelectAllExampleDemo = () => { + const [items, setItems] = useState([ + { id: '1', label: 'Item 1', checked: false }, + { id: '2', label: 'Item 2', checked: false }, + { id: '3', label: 'Item 3', checked: false }, + ]) + + const allChecked = items.every(item => item.checked) + const someChecked = items.some(item => item.checked) + const indeterminate = someChecked && !allChecked + + const toggleAll = () => { + const newChecked = !allChecked + setItems(items.map(item => ({ ...item, checked: newChecked }))) + } + + const toggleItem = createToggleItem(items, setItems) + + return ( +
+
+ + Select All +
+
+ {items.map(item => ( +
+ toggleItem(item.id)} + /> + +
+ ))} +
+
+ ) +} + +export const SelectAllExample: Story = { + render: () => , +} + +// Form example +const FormExampleDemo = () => { + const [formData, setFormData] = useState({ + terms: false, + newsletter: false, + privacy: false, + }) + + return ( +
+

Account Settings

+
+
+ setFormData({ ...formData, terms: !formData.terms })} + /> +
+ +

+ Required to continue +

+
+
+
+ setFormData({ ...formData, newsletter: !formData.newsletter })} + /> +
+ +

+ Get updates about new features +

+
+
+
+ setFormData({ ...formData, privacy: !formData.privacy })} + /> +
+ +

+ Required to continue +

+
+
+
+
+ ) +} + +export const FormExample: Story = { + render: () => , +} + +// Task list example +const TaskListExampleDemo = () => { + const [tasks, setTasks] = useState([ + { id: '1', title: 'Review pull request', completed: true }, + { id: '2', title: 'Update documentation', completed: true }, + { id: '3', title: 'Fix navigation bug', completed: false }, + { id: '4', title: 'Deploy to staging', completed: false }, + ]) + + const toggleTask = (id: string) => { + setTasks(tasks.map(task => + task.id === id ? { ...task, completed: !task.completed } : task, + )) + } + + const completedCount = tasks.filter(t => t.completed).length + + return ( +
+
+

Today's Tasks

+ + {completedCount} of {tasks.length} completed + +
+
+ {tasks.map(task => ( +
+ toggleTask(task.id)} + /> + toggleTask(task.id)} + > + {task.title} + +
+ ))} +
+
+ ) +} + +export const TaskListExample: Story = { + render: () => , +} + +// Interactive playground +export const Playground: Story = { + render: args => , + args: { + checked: false, + indeterminate: false, + disabled: false, + id: 'playground-checkbox', + }, +} diff --git a/web/app/components/base/input-number/index.stories.tsx b/web/app/components/base/input-number/index.stories.tsx new file mode 100644 index 0000000000..0fca2e52f9 --- /dev/null +++ b/web/app/components/base/input-number/index.stories.tsx @@ -0,0 +1,438 @@ +import type { Meta, StoryObj } from '@storybook/nextjs' +import { useState } from 'react' +import { InputNumber } from '.' + +const meta = { + title: 'Base/InputNumber', + component: InputNumber, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Number input component with increment/decrement buttons. Supports min/max constraints, custom step amounts, and units display.', + }, + }, + }, + tags: ['autodocs'], + argTypes: { + value: { + control: 'number', + description: 'Current value', + }, + size: { + control: 'select', + options: ['regular', 'large'], + description: 'Input size', + }, + min: { + control: 'number', + description: 'Minimum value', + }, + max: { + control: 'number', + description: 'Maximum value', + }, + amount: { + control: 'number', + description: 'Step amount for increment/decrement', + }, + unit: { + control: 'text', + description: 'Unit text displayed (e.g., "px", "ms")', + }, + disabled: { + control: 'boolean', + description: 'Disabled state', + }, + defaultValue: { + control: 'number', + description: 'Default value when undefined', + }, + }, +} satisfies Meta + +export default meta +type Story = StoryObj + +// Interactive demo wrapper +const InputNumberDemo = (args: any) => { + const [value, setValue] = useState(args.value ?? 0) + + return ( +
+ { + setValue(newValue) + console.log('Value changed:', newValue) + }} + /> +
+ Current value: {value} +
+
+ ) +} + +// Default state +export const Default: Story = { + render: args => , + args: { + value: 0, + size: 'regular', + }, +} + +// Large size +export const LargeSize: Story = { + render: args => , + args: { + value: 10, + size: 'large', + }, +} + +// With min/max constraints +export const WithMinMax: Story = { + render: args => , + args: { + value: 5, + min: 0, + max: 10, + size: 'regular', + }, +} + +// With custom step amount +export const CustomStepAmount: Story = { + render: args => , + args: { + value: 50, + amount: 5, + min: 0, + max: 100, + size: 'regular', + }, +} + +// With unit +export const WithUnit: Story = { + render: args => , + args: { + value: 100, + unit: 'px', + min: 0, + max: 1000, + amount: 10, + size: 'regular', + }, +} + +// Disabled state +export const Disabled: Story = { + render: args => , + args: { + value: 42, + disabled: true, + size: 'regular', + }, +} + +// Decimal values +export const DecimalValues: Story = { + render: args => , + args: { + value: 2.5, + amount: 0.5, + min: 0, + max: 10, + size: 'regular', + }, +} + +// Negative values allowed +export const NegativeValues: Story = { + render: args => , + args: { + value: 0, + min: -100, + max: 100, + amount: 10, + size: 'regular', + }, +} + +// Size comparison +const SizeComparisonDemo = () => { + const [regularValue, setRegularValue] = useState(10) + const [largeValue, setLargeValue] = useState(20) + + return ( +
+
+ + +
+
+ + +
+
+ ) +} + +export const SizeComparison: Story = { + render: () => , +} + +// Real-world example - Font size picker +const FontSizePickerDemo = () => { + const [fontSize, setFontSize] = useState(16) + + return ( +
+
+
+ + +
+
+

+ Preview Text +

+
+
+
+ ) +} + +export const FontSizePicker: Story = { + render: () => , +} + +// Real-world example - Quantity selector +const QuantitySelectorDemo = () => { + const [quantity, setQuantity] = useState(1) + const pricePerItem = 29.99 + const total = (quantity * pricePerItem).toFixed(2) + + return ( +
+
+
+
+

Product Name

+

${pricePerItem} each

+
+
+
+ + +
+
+
+ Total + ${total} +
+
+
+
+ ) +} + +export const QuantitySelector: Story = { + render: () => , +} + +// Real-world example - Timer settings +const TimerSettingsDemo = () => { + const [hours, setHours] = useState(0) + const [minutes, setMinutes] = useState(15) + const [seconds, setSeconds] = useState(30) + + const totalSeconds = hours * 3600 + minutes * 60 + seconds + + return ( +
+

Timer Configuration

+
+
+ + +
+
+ + +
+
+ + +
+
+
+ Total duration: {totalSeconds} seconds +
+
+
+
+ ) +} + +export const TimerSettings: Story = { + render: () => , +} + +// Real-world example - Animation settings +const AnimationSettingsDemo = () => { + const [duration, setDuration] = useState(300) + const [delay, setDelay] = useState(0) + const [iterations, setIterations] = useState(1) + + return ( +
+

Animation Properties

+
+
+ + +
+
+ + +
+
+ + +
+
+
+ animation: {duration}ms {delay}ms {iterations} +
+
+
+
+ ) +} + +export const AnimationSettings: Story = { + render: () => , +} + +// Real-world example - Temperature control +const TemperatureControlDemo = () => { + const [temperature, setTemperature] = useState(20) + const fahrenheit = ((temperature * 9) / 5 + 32).toFixed(1) + + return ( +
+

Temperature Control

+
+
+ + +
+
+
+
Celsius
+
{temperature}°C
+
+
+
Fahrenheit
+
{fahrenheit}°F
+
+
+
+
+ ) +} + +export const TemperatureControl: Story = { + render: () => , +} + +// Interactive playground +export const Playground: Story = { + render: args => , + args: { + value: 10, + size: 'regular', + min: 0, + max: 100, + amount: 1, + unit: '', + disabled: false, + defaultValue: 0, + }, +} diff --git a/web/app/components/base/input/index.stories.tsx b/web/app/components/base/input/index.stories.tsx new file mode 100644 index 0000000000..cd857bc180 --- /dev/null +++ b/web/app/components/base/input/index.stories.tsx @@ -0,0 +1,424 @@ +import type { Meta, StoryObj } from '@storybook/nextjs' +import { useState } from 'react' +import Input from '.' + +const meta = { + title: 'Base/Input', + component: Input, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Input component with support for icons, clear button, validation states, and units. Includes automatic leading zero removal for number inputs.', + }, + }, + }, + tags: ['autodocs'], + argTypes: { + size: { + control: 'select', + options: ['regular', 'large'], + description: 'Input size', + }, + type: { + control: 'select', + options: ['text', 'number', 'email', 'password', 'url', 'tel'], + description: 'Input type', + }, + placeholder: { + control: 'text', + description: 'Placeholder text', + }, + disabled: { + control: 'boolean', + description: 'Disabled state', + }, + destructive: { + control: 'boolean', + description: 'Error/destructive state', + }, + showLeftIcon: { + control: 'boolean', + description: 'Show search icon on left', + }, + showClearIcon: { + control: 'boolean', + description: 'Show clear button when input has value', + }, + unit: { + control: 'text', + description: 'Unit text displayed on right (e.g., "px", "ms")', + }, + }, +} satisfies Meta + +export default meta +type Story = StoryObj + +// Interactive demo wrapper +const InputDemo = (args: any) => { + const [value, setValue] = useState(args.value || '') + + return ( +
+ { + setValue(e.target.value) + console.log('Input changed:', e.target.value) + }} + onClear={() => { + setValue('') + console.log('Input cleared') + }} + /> +
+ ) +} + +// Default state +export const Default: Story = { + render: args => , + args: { + size: 'regular', + placeholder: 'Enter text...', + type: 'text', + }, +} + +// Large size +export const LargeSize: Story = { + render: args => , + args: { + size: 'large', + placeholder: 'Enter text...', + type: 'text', + }, +} + +// With search icon +export const WithSearchIcon: Story = { + render: args => , + args: { + size: 'regular', + showLeftIcon: true, + placeholder: 'Search...', + type: 'text', + }, +} + +// With clear button +export const WithClearButton: Story = { + render: args => , + args: { + size: 'regular', + showClearIcon: true, + value: 'Some text to clear', + placeholder: 'Type something...', + type: 'text', + }, +} + +// Search input (icon + clear) +export const SearchInput: Story = { + render: args => , + args: { + size: 'regular', + showLeftIcon: true, + showClearIcon: true, + value: '', + placeholder: 'Search...', + type: 'text', + }, +} + +// Disabled state +export const Disabled: Story = { + render: args => , + args: { + size: 'regular', + value: 'Disabled input', + disabled: true, + type: 'text', + }, +} + +// Destructive/error state +export const DestructiveState: Story = { + render: args => , + args: { + size: 'regular', + value: 'invalid@email', + destructive: true, + placeholder: 'Enter email...', + type: 'email', + }, +} + +// Number input +export const NumberInput: Story = { + render: args => , + args: { + size: 'regular', + type: 'number', + placeholder: 'Enter a number...', + value: '0', + }, +} + +// With unit +export const WithUnit: Story = { + render: args => , + args: { + size: 'regular', + type: 'number', + value: '100', + unit: 'px', + placeholder: 'Enter value...', + }, +} + +// Email input +export const EmailInput: Story = { + render: args => , + args: { + size: 'regular', + type: 'email', + placeholder: 'Enter your email...', + showClearIcon: true, + }, +} + +// Password input +export const PasswordInput: Story = { + render: args => , + args: { + size: 'regular', + type: 'password', + placeholder: 'Enter password...', + value: 'secret123', + }, +} + +// Size comparison +const SizeComparisonDemo = () => { + const [regularValue, setRegularValue] = useState('') + const [largeValue, setLargeValue] = useState('') + + return ( +
+
+ + setRegularValue(e.target.value)} + placeholder="Regular input..." + showClearIcon + onClear={() => setRegularValue('')} + /> +
+
+ + setLargeValue(e.target.value)} + placeholder="Large input..." + showClearIcon + onClear={() => setLargeValue('')} + /> +
+
+ ) +} + +export const SizeComparison: Story = { + render: () => , +} + +// State comparison +const StateComparisonDemo = () => { + const [normalValue, setNormalValue] = useState('Normal state') + const [errorValue, setErrorValue] = useState('Error state') + + return ( +
+
+ + setNormalValue(e.target.value)} + showClearIcon + onClear={() => setNormalValue('')} + /> +
+
+ + setErrorValue(e.target.value)} + destructive + /> +
+
+ + undefined} + disabled + /> +
+
+ ) +} + +export const StateComparison: Story = { + render: () => , +} + +// Form example +const FormExampleDemo = () => { + const [formData, setFormData] = useState({ + name: '', + email: '', + age: '', + website: '', + }) + const [errors, setErrors] = useState({ + email: false, + age: false, + }) + + const validateEmail = (email: string) => { + return /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(email) + } + + return ( +
+

User Profile

+
+
+ + setFormData({ ...formData, name: e.target.value })} + placeholder="Enter your name..." + showClearIcon + onClear={() => setFormData({ ...formData, name: '' })} + /> +
+
+ + { + setFormData({ ...formData, email: e.target.value }) + setErrors({ ...errors, email: e.target.value ? !validateEmail(e.target.value) : false }) + }} + placeholder="Enter your email..." + destructive={errors.email} + showClearIcon + onClear={() => { + setFormData({ ...formData, email: '' }) + setErrors({ ...errors, email: false }) + }} + /> + {errors.email && ( + Please enter a valid email address + )} +
+
+ + { + setFormData({ ...formData, age: e.target.value }) + setErrors({ ...errors, age: e.target.value ? Number(e.target.value) < 18 : false }) + }} + placeholder="Enter your age..." + destructive={errors.age} + unit="years" + /> + {errors.age && ( + Must be 18 or older + )} +
+
+ + setFormData({ ...formData, website: e.target.value })} + placeholder="https://example.com" + showClearIcon + onClear={() => setFormData({ ...formData, website: '' })} + /> +
+
+
+ ) +} + +export const FormExample: Story = { + render: () => , +} + +// Search example +const SearchExampleDemo = () => { + const [searchQuery, setSearchQuery] = useState('') + const items = ['Apple', 'Banana', 'Cherry', 'Date', 'Elderberry', 'Fig', 'Grape'] + const filteredItems = items.filter(item => + item.toLowerCase().includes(searchQuery.toLowerCase()), + ) + + return ( +
+ setSearchQuery(e.target.value)} + onClear={() => setSearchQuery('')} + placeholder="Search fruits..." + /> + {searchQuery && ( +
+
+ {filteredItems.length} result{filteredItems.length !== 1 ? 's' : ''} +
+
+ {filteredItems.map(item => ( +
+ {item} +
+ ))} +
+
+ )} +
+ ) +} + +export const SearchExample: Story = { + render: () => , +} + +// Interactive playground +export const Playground: Story = { + render: args => , + args: { + size: 'regular', + type: 'text', + placeholder: 'Type something...', + disabled: false, + destructive: false, + showLeftIcon: false, + showClearIcon: true, + unit: '', + }, +} diff --git a/web/app/components/base/prompt-editor/index.stories.tsx b/web/app/components/base/prompt-editor/index.stories.tsx new file mode 100644 index 0000000000..17b04e4af0 --- /dev/null +++ b/web/app/components/base/prompt-editor/index.stories.tsx @@ -0,0 +1,360 @@ +import type { Meta, StoryObj } from '@storybook/nextjs' +import { useState } from 'react' + +// Mock component to avoid complex initialization issues +const PromptEditorMock = ({ value, onChange, placeholder, editable, compact, className, wrapperClassName }: any) => { + const [content, setContent] = useState(value || '') + + const handleChange = (e: React.ChangeEvent) => { + setContent(e.target.value) + onChange?.(e.target.value) + } + + return ( +
+