Merge branch 'main' into feat/memory-orchestration-fed

This commit is contained in:
zxhlyh 2025-10-21 13:01:37 +08:00
commit c8188274a2
375 changed files with 65700 additions and 43160 deletions

View File

@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
default="postgresql",
)
@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
default=os.cpu_count() or 1,
)
@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'

View File

@ -24,7 +24,7 @@ except ImportError:
)
else:
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
magic = None # type: ignore
magic = None # type: ignore[assignment]
from pydantic import BaseModel

View File

@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=streaming,
)
# FIXME: Type hinting issue here, ignore it for now, will fix it later
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,

View File

@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -98,7 +98,7 @@ class RateLimit:
else:
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
generator=generator,
request_id=request_id,
)

View File

@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e # ty: ignore [invalid-assignment]
err = e
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))

View File

@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
if "/" not in key:
key = str(ModelProviderID(key))
return self.configurations.get(key, default) # type: ignore
return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel):

View File

@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
# FIXME: mypy does not support the type of spec.loader
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment]
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader:

View File

@ -49,62 +49,80 @@ class IndexingRunner:
self.storage = storage
self.model_manager = ModelManager()
def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
"""Handle indexing errors by updating document status."""
logger.exception("consume document failed")
document = db.session.get(DatasetDocument, document_id)
if document:
document.indexing_status = "error"
error_message = getattr(error, "description", str(error))
document.error = str(error_message)
document.stopped_at = naive_utc_now()
db.session.commit()
def run(self, dataset_documents: list[DatasetDocument]):
"""Run the indexing process."""
for dataset_document in dataset_documents:
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found, skipping document id: %s", document_id)
continue
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get the process rule
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
DatasetProcessRule.id == requeried_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
documents = self._transform(
index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
)
# save segment
self._load_segments(dataset, dataset_document, documents)
self._load_segments(dataset, requeried_document, documents)
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except ObjectDeletedError:
logger.warning("Document deleted, document id: %s", dataset_document.id)
logger.warning("Document deleted, document id: %s", document_id)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found: %s", document_id)
return
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@ -112,57 +130,60 @@ class IndexingRunner:
# get exist document_segment list and delete
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
documents = self._transform(
index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
)
# save segment
self._load_segments(dataset, dataset_document, documents)
self._load_segments(dataset, requeried_document, documents)
# load
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
index_processor=index_processor,
dataset=dataset,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is indexing."""
document_id = dataset_document.id
try:
# Re-query the document to ensure it's bound to the current session
requeried_document = db.session.get(DatasetDocument, document_id)
if not requeried_document:
logger.warning("Document not found: %s", document_id)
return
# get dataset
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@ -170,7 +191,7 @@ class IndexingRunner:
# get exist document_segment list and delete
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
@ -188,7 +209,7 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.get_child_chunks()
if child_chunks:
child_documents = []
@ -206,24 +227,20 @@ class IndexingRunner:
document.children = child_documents
documents.append(document)
# build index
index_type = dataset_document.doc_form
index_type = requeried_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
index_processor=index_processor,
dataset=dataset,
dataset_document=requeried_document,
documents=documents,
)
except DocumentIsPausedError:
raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
except Exception as e:
logger.exception("consume document failed")
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
dataset_document.stopped_at = naive_utc_now()
db.session.commit()
self._handle_indexing_error(document_id, e)
def indexing_estimate(
self,

View File

@ -2,7 +2,7 @@ import logging
import os
from datetime import datetime, timedelta
from langfuse import Langfuse # type: ignore
from langfuse import Langfuse
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance

View File

@ -180,7 +180,7 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type_(**response.json()) # type: ignore
return type_(**response.json()) # type: ignore[return-value]
def _request_with_plugin_daemon_response(
self,

View File

@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id

View File

@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id

View File

@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:
try:
repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return]
return repository_class(
session_factory=session_factory,
user=user,
app_id=app_id,
@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:
try:
repository_class = import_string(class_path)
return repository_class( # type: ignore[no-any-return]
return repository_class(
session_factory=session_factory,
user=user,
app_id=app_id,

View File

@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
@property
def need_credentials(self) -> bool:

View File

@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
tenant_id=self.runtime.tenant_id,
voice=voice, # type: ignore
voice=voice,
)
buffer = io.BytesIO()
for chunk in tts:

View File

@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
yield self.create_text_message(f"{timestamp}")
# TODO: this method's type is messy
@staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
try:

View File

@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
datetime_with_tz = input_timezone.localize(local_time)
# timezone convert
converted_datetime = datetime_with_tz.astimezone(output_timezone)
return converted_datetime.strftime(format=time_format) # type: ignore
return converted_datetime.strftime(time_format)
except Exception as e:
raise ToolInvokeError(str(e))

View File

@ -105,7 +105,7 @@ class MCPToolProviderController(ToolProviderController):
"""
pass
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
def get_tool(self, tool_name: str) -> MCPTool:
"""
return tool with given name
"""
@ -128,7 +128,7 @@ class MCPToolProviderController(ToolProviderController):
sse_read_timeout=self.sse_read_timeout,
)
def get_tools(self) -> list[MCPTool]: # type: ignore
def get_tools(self) -> list[MCPTool]:
"""
get all tools
"""

View File

@ -26,7 +26,7 @@ class ToolLabelManager:
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
provider_id = controller.provider_id
else:
raise ValueError("Unsupported tool type")
@ -51,7 +51,7 @@ class ToolLabelManager:
Get tool labels
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels
else:
@ -85,7 +85,7 @@ class ToolLabelManager:
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
provider_ids.append(controller.provider_id)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

View File

@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt) # type: ignore
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id, # type: ignore
document_name=document.name, # type: ignore
data_source_type=document.data_source_type, # type: ignore
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=record.score or 0.0,
doc_metadata=document.doc_metadata, # type: ignore
doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":

View File

@ -6,8 +6,8 @@ from typing import Any, cast
from urllib.parse import unquote
import chardet
import cloudscraper # type: ignore
from readabilipy import simple_json_from_html_string # type: ignore
import cloudscraper
from readabilipy import simple_json_from_html_string
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request # type: ignore
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore
scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200:
return f"URL returned status code {response.status_code}."

View File

@ -3,7 +3,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any
import yaml # type: ignore
import yaml
from yaml import YAMLError
logger = logging.getLogger(__name__)

View File

@ -99,7 +99,7 @@ class WorkflowToolProviderController(ToolProviderController):
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user

View File

@ -4,7 +4,7 @@ from .types import SegmentType
class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP
value: list[Segment] = None # type: ignore
value: list[Segment]
@property
def text(self):

View File

@ -19,7 +19,7 @@ class Segment(BaseModel):
model_config = ConfigDict(frozen=True)
value_type: SegmentType
value: Any = None
value: Any
@field_validator("value_type")
@classmethod
@ -74,12 +74,12 @@ class NoneSegment(Segment):
class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING
value: str = None # type: ignore
value: str
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.FLOAT
value: float = None # type: ignore
value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
#
@ -98,12 +98,12 @@ class FloatSegment(Segment):
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.INTEGER
value: int = None # type: ignore
value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any] = None # type: ignore
value: Mapping[str, Any]
@property
def text(self) -> str:
@ -136,7 +136,7 @@ class ArraySegment(Segment):
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
value: File = None # type: ignore
value: File
@property
def markdown(self) -> str:
@ -153,17 +153,17 @@ class FileSegment(Segment):
class BooleanSegment(Segment):
value_type: SegmentType = SegmentType.BOOLEAN
value: bool = None # type: ignore
value: bool
class ArrayAnySegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_ANY
value: Sequence[Any] = None # type: ignore
value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str] = None # type: ignore
value: Sequence[str]
@property
def text(self) -> str:
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER
value: Sequence[float | int] = None # type: ignore
value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]] = None # type: ignore
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[File] = None # type: ignore
value: Sequence[File]
@property
def markdown(self) -> str:
@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
class ArrayBooleanSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
value: Sequence[bool] = None # type: ignore
value: Sequence[bool]
def get_segment_discriminator(v: Any) -> SegmentType | None:

View File

@ -1,3 +1,5 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
@ -201,6 +202,17 @@ class Graph:
return GraphBuilder(graph_cls=cls)
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
:param nodes: mapping of node ID to node instance
"""
for node in nodes.values():
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
node.execution_type = NodeExecutionType.BRANCH
@classmethod
def _mark_inactive_root_branches(
cls,
@ -307,6 +319,9 @@ class Graph:
# Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory)
# Promote fail-branch nodes to branch execution type at graph level
cls._promote_fail_branch_nodes(nodes)
# Get root node instance
root_node = nodes[root_node_id]
@ -314,7 +329,7 @@ class Graph:
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
return cls(
graph = cls(
nodes=nodes,
edges=edges,
in_edges=in_edges,
@ -322,6 +337,11 @@ class Graph:
root_node=root_node,
)
# Validate the graph structure using built-in validators
get_graph_validator().validate(graph)
return graph
@property
def node_ids(self) -> list[str]:
"""

View File

@ -0,0 +1,125 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from core.workflow.enums import NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@dataclass(frozen=True, slots=True)
class GraphValidationIssue:
"""Immutable value object describing a single validation issue."""
code: str
message: str
node_id: str | None = None
class GraphValidationError(ValueError):
"""Raised when graph validation fails."""
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
if not issues:
raise ValueError("GraphValidationError requires at least one issue.")
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
super().__init__(message)
class GraphValidationRule(Protocol):
"""Protocol that individual validation rules must satisfy."""
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
"""Validate the provided graph and return any discovered issues."""
...
@dataclass(frozen=True, slots=True)
class _EdgeEndpointValidator:
"""Ensures all edges reference existing nodes."""
missing_node_code: str = "MISSING_NODE"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
issues: list[GraphValidationIssue] = []
for edge in graph.edges.values():
if edge.tail not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
node_id=edge.tail,
)
)
if edge.head not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
node_id=edge.head,
)
)
return issues
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
issues: list[GraphValidationIssue] = []
if root_node.id not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' is missing from the node registry.",
node_id=root_node.id,
)
)
return issues
node_type = getattr(root_node, "node_type", None)
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
node_id=root_node.id,
)
)
return issues
@dataclass(frozen=True, slots=True)
class GraphValidator:
"""Coordinates execution of graph validation rules."""
rules: tuple[GraphValidationRule, ...]
def validate(self, graph: Graph) -> None:
"""Validate the graph against all configured rules."""
issues: list[GraphValidationIssue] = []
for rule in self.rules:
issues.extend(rule.validate(graph))
if issues:
raise GraphValidationError(issues)
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
)
def get_graph_validator() -> GraphValidator:
"""Construct the validator composed of default rules."""
return GraphValidator(_DEFAULT_RULES)

View File

@ -1,5 +1,6 @@
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:

View File

@ -10,10 +10,10 @@ from typing import Any
import chardet
import docx
import pandas as pd
import pypandoc # type: ignore
import pypdfium2 # type: ignore
import webvtt # type: ignore
import yaml # type: ignore
import pypandoc
import pypdfium2
import webvtt
import yaml
from docx.document import Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P

View File

@ -141,7 +141,7 @@ class KnowledgeRetrievalNode(Node):
def version(cls):
return "1"
def _run(self) -> NodeRunResult: # type: ignore
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
@ -443,7 +443,7 @@ class KnowledgeRetrievalNode(Node):
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or", # type: ignore
else "or",
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
@ -457,10 +457,10 @@ class KnowledgeRetrievalNode(Node):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
@ -487,7 +487,7 @@ class KnowledgeRetrievalNode(Node):
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore
):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"Node {node_id} missing data information")
node_instance.init_node_data(node_data)
# If node has fail branch, change execution type to branch
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
node_instance.execution_type = NodeExecutionType.BRANCH
return node_instance

View File

@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]

View File

@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></his
### Instructions:
Some extra information are provided below, you should always follow the instructions as possible as you can.
<instructions>
{{instructions}}
{instructions}
</instructions>
"""

View File

@ -260,7 +260,7 @@ class VariablePool(BaseModel):
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value) # type: ignore
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":

View File

@ -1,7 +1,12 @@
from configs import dify_config
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
from dify_app import DifyApp
BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
def init_app(app: DifyApp):
# register blueprint routers
@ -17,7 +22,7 @@ def init_app(app: DifyApp):
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
@ -26,7 +31,7 @@ def init_app(app: DifyApp):
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN],
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
@ -36,7 +41,7 @@ def init_app(app: DifyApp):
console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
@ -44,7 +49,7 @@ def init_app(app: DifyApp):
CORS(
files_bp,
allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(files_bp)

View File

@ -7,7 +7,7 @@ def is_enabled() -> bool:
def init_app(app: DifyApp):
from flask_compress import Compress # type: ignore
from flask_compress import Compress
compress = Compress()
compress.init_app(app)

View File

@ -1,6 +1,6 @@
import json
import flask_login # type: ignore
import flask_login
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized

View File

@ -2,7 +2,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
import flask_migrate # type: ignore
import flask_migrate
from extensions.ext_database import db

View File

@ -103,7 +103,7 @@ def init_app(app: DifyApp):
def shutdown_tracer():
provider = trace.get_tracer_provider()
if hasattr(provider, "force_flush"):
provider.force_flush() # ty: ignore [call-non-callable]
provider.force_flush()
class ExceptionLoggingHandler(logging.Handler):
"""Custom logging handler that creates spans for logging.exception() calls"""

View File

@ -6,4 +6,4 @@ def init_app(app: DifyApp):
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
from werkzeug.middleware.proxy_fix import ProxyFix
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign]

View File

@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import sentry_sdk
from langfuse import parse_error # type: ignore
from langfuse import parse_error
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException

View File

@ -1,7 +1,7 @@
import posixpath
from collections.abc import Generator
import oss2 as aliyun_s3 # type: ignore
import oss2 as aliyun_s3
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -2,9 +2,9 @@ import base64
import hashlib
from collections.abc import Generator
from baidubce.auth.bce_credentials import BceCredentials # type: ignore
from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
from baidubce.services.bos.bos_client import BosClient # type: ignore
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -11,7 +11,7 @@ from collections.abc import Generator
from io import BytesIO
from pathlib import Path
import clickzetta # type: ignore[import]
import clickzetta
from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage

View File

@ -34,7 +34,7 @@ class VolumePermissionManager:
# Support two initialization methods: connection object or configuration dictionary
if isinstance(connection_or_config, dict):
# Create connection from configuration dictionary
import clickzetta # type: ignore[import-untyped]
import clickzetta
config = connection_or_config
self._connection = clickzetta.connect(

View File

@ -3,7 +3,7 @@ import io
import json
from collections.abc import Generator
from google.cloud import storage as google_cloud_storage # type: ignore
from google.cloud import storage as google_cloud_storage
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from obs import ObsClient # type: ignore
from obs import ObsClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
import boto3 # type: ignore
from botocore.exceptions import ClientError # type: ignore
import boto3
from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from qcloud_cos import CosConfig, CosS3Client # type: ignore
from qcloud_cos import CosConfig, CosS3Client
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
import tos # type: ignore
import tos
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -146,6 +146,6 @@ class ExternalApi(Api):
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# manual separate call on construction and init_app to ensure configs in kwargs effective
super().__init__(app=None, *args, **kwargs) # type: ignore
super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs)
register_external_error_handlers(self)

View File

@ -23,7 +23,7 @@ from hashlib import sha1
import Crypto.Hash.SHA1
import Crypto.Util.number
import gmpy2 # type: ignore
import gmpy2
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
@ -136,7 +136,7 @@ class PKCS1OAepCipher:
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute]
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
# Step 3c (I2OSP)
c = long_to_bytes(m_int, k)
return c
@ -169,7 +169,7 @@ class PKCS1OAepCipher:
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int)
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute]
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3a
@ -191,12 +191,12 @@ class PKCS1OAepCipher:
# Step 3g
one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen]
invalid = bord(y) | int(one_pos < hLen) # type: ignore
invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type]
hash_compare = strxor(lHash1, lHash)
for x in hash_compare:
invalid |= bord(x) # type: ignore
invalid |= bord(x) # type: ignore[arg-type]
for x in db[hLen:one_pos]:
invalid |= bord(x) # type: ignore
invalid |= bord(x) # type: ignore[arg-type]
if invalid != 0:
raise ValueError("Incorrect decryption.")
# Step 4

View File

@ -3,7 +3,7 @@ from functools import wraps
from typing import Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS # type: ignore
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
return g._login_user # type: ignore
return g._login_user
return None

View File

@ -1,8 +1,8 @@
import logging
import sendgrid # type: ignore
import sendgrid
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore
from sendgrid.helpers.mail import Content, Email, Mail, To
logger = logging.getLogger(__name__)

View File

@ -5,7 +5,7 @@ from datetime import datetime
from typing import Any, Optional
import sqlalchemy as sa
from flask_login import UserMixin # type: ignore[import-untyped]
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated

View File

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column

View File

@ -16,7 +16,25 @@
"opentelemetry.instrumentation.requests",
"opentelemetry.instrumentation.sqlalchemy",
"opentelemetry.instrumentation.redis",
"opentelemetry.instrumentation.httpx"
"langfuse",
"cloudscraper",
"readabilipy",
"pypandoc",
"pypdfium2",
"webvtt",
"flask_compress",
"oss2",
"baidubce.auth.bce_credentials",
"baidubce.bce_client_configuration",
"baidubce.services.bos.bos_client",
"clickzetta",
"google.cloud",
"obs",
"qcloud_cos",
"tos",
"gmpy2",
"sendgrid",
"sendgrid.helpers.mail"
],
"reportUnknownMemberType": "hint",
"reportUnknownParameterType": "hint",
@ -28,7 +46,7 @@
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.11",
"pythonPlatform": "All"

View File

@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
try:
repository_class = import_string(class_path)
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
return repository_class(session_maker=session_maker)
except (ImportError, Exception) as e:
raise RepositoryImportError(
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
try:
repository_class = import_string(class_path)
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
return repository_class(session_maker=session_maker)
except (ImportError, Exception) as e:
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e

View File

@ -7,7 +7,7 @@ from enum import StrEnum
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
import yaml
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
@ -563,7 +563,7 @@ class AppDslService:
else:
cls._append_model_config_export_data(export_data, app_model)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
return yaml.dump(export_data, allow_unicode=True)
@classmethod
def _append_workflow_export_data(

View File

@ -241,9 +241,9 @@ class DatasetService:
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore
dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
db.session.add(dataset)
@ -1416,6 +1416,8 @@ class DocumentService:
# check document limit
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
assert knowledge_config.data_source
assert knowledge_config.data_source.info_list.file_info_list
features = FeatureService.get_features(current_user.current_tenant_id)
@ -1424,15 +1426,16 @@ class DocumentService:
count = 0
if knowledge_config.data_source:
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
for notion_info in notion_info_list: # type: ignore
notion_info_list = knowledge_config.data_source.info_list.notion_info_list or []
for notion_info in notion_info_list:
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
count = len(website_info.urls) # type: ignore
assert website_info
count = len(website_info.urls)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == "sandbox" and count > 1:
@ -1444,7 +1447,7 @@ class DocumentService:
# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
if not dataset.indexing_technique:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
@ -1481,7 +1484,7 @@ class DocumentService:
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
)
documents = []
if knowledge_config.original_document_id:
@ -1523,11 +1526,12 @@ class DocumentService:
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
with redis_client.lock(lock_name, timeout=600):
assert dataset_process_rule
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
@ -1540,7 +1544,7 @@ class DocumentService:
raise FileNotExistsError()
file_name = file.name
data_source_info = {
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
}
# check duplicate
@ -1557,7 +1561,7 @@ class DocumentService:
.first()
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
@ -1571,8 +1575,8 @@ class DocumentService:
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
@ -1587,7 +1591,7 @@ class DocumentService:
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
@ -1616,15 +1620,15 @@ class DocumentService:
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
@ -1644,8 +1648,8 @@ class DocumentService:
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
@ -1663,8 +1667,8 @@ class DocumentService:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
@ -2071,7 +2075,7 @@ class DocumentService:
# update document data source
if document_data.data_source:
file_name = ""
data_source_info = {}
data_source_info: dict[str, str | bool] = {}
if document_data.data_source.info_list.data_source_type == "upload_file":
if not document_data.data_source.info_list.file_info_list:
raise ValueError("No file info list found.")
@ -2128,7 +2132,7 @@ class DocumentService:
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content, # type: ignore
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
document.data_source_type = document_data.data_source.info_list.data_source_type
@ -2154,7 +2158,7 @@ class DocumentService:
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
{DocumentSegment.status: "re_segment"}
) # type: ignore
)
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
@ -2164,25 +2168,26 @@ class DocumentService:
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
assert knowledge_config.data_source
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
count = 0
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = (
knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
if knowledge_config.data_source.info_list.file_info_list # type: ignore
knowledge_config.data_source.info_list.file_info_list.file_ids
if knowledge_config.data_source.info_list.file_info_list
else []
)
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
if notion_info_list:
for notion_info in notion_info_list:
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if website_info:
count = len(website_info.urls)
if features.billing.subscription.plan == "sandbox" and count > 1:
@ -2196,9 +2201,11 @@ class DocumentService:
dataset_collection_binding_id = None
retrieval_model = None
if knowledge_config.indexing_technique == "high_quality":
assert knowledge_config.embedding_model_provider
assert knowledge_config.embedding_model
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
knowledge_config.embedding_model_provider, # type: ignore
knowledge_config.embedding_model, # type: ignore
knowledge_config.embedding_model_provider,
knowledge_config.embedding_model,
)
dataset_collection_binding_id = dataset_collection_binding.id
if knowledge_config.retrieval_model:
@ -2215,7 +2222,7 @@ class DocumentService:
dataset = Dataset(
tenant_id=tenant_id,
name="",
data_source_type=knowledge_config.data_source.info_list.data_source_type, # type: ignore
data_source_type=knowledge_config.data_source.info_list.data_source_type,
indexing_technique=knowledge_config.indexing_technique,
created_by=account.id,
embedding_model=knowledge_config.embedding_model,
@ -2224,7 +2231,7 @@ class DocumentService:
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
)
db.session.add(dataset) # type: ignore
db.session.add(dataset)
db.session.flush()
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)

View File

@ -88,7 +88,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(query, all_documents) # type: ignore
return cls.compact_retrieve_response(query, all_documents)
@classmethod
def external_retrieve(

View File

@ -1,4 +1,4 @@
import boto3 # type: ignore
import boto3
from configs import dify_config

View File

@ -89,7 +89,7 @@ class MetadataService:
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
return metadata # type: ignore
return metadata
except Exception:
logger.exception("Update metadata name failed")
finally:

View File

@ -137,7 +137,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
return provider_configuration.get_provider_credential(credential_id=credential_id)
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
"""
@ -225,7 +225,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_custom_model_credential( # type: ignore
return provider_configuration.get_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)

View File

@ -146,7 +146,7 @@ class PluginMigration:
futures.append(
thread_pool.submit(
process_tenant,
current_app._get_current_object(), # type: ignore[attr-defined]
current_app._get_current_object(), # type: ignore
tenant_id,
)
)

View File

@ -544,8 +544,8 @@ class BuiltinToolManageService:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.entity.identity.name,
):

View File

@ -308,7 +308,7 @@ class MCPToolManageService:
provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
config=list(provider_controller.get_credentials_schema()),
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)

View File

@ -102,7 +102,7 @@ def batch_create_segment_to_index_task(
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content) # type: ignore
segment_hash = helper.generate_text_hash(content)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)

View File

@ -5,11 +5,11 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from pymochow import MochowClient # type: ignore
from pymochow.model.database import Database # type: ignore
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore
from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore
from pymochow.model.table import Table # type: ignore
from pymochow import MochowClient
from pymochow.model.database import Database
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
from pymochow.model.schema import HNSWParams, VectorIndex
from pymochow.model.table import Table
class AttrDict(UserDict):

View File

@ -3,15 +3,15 @@ from typing import Any, Union
import pytest
from _pytest.monkeypatch import MonkeyPatch
from tcvectordb import RPCVectorDBClient # type: ignore
from tcvectordb import RPCVectorDBClient
from tcvectordb.model import enum
from tcvectordb.model.collection import FilterIndexConfig
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
from tcvectordb.model.enum import ReadConsistency # type: ignore
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
from tcvectordb.model.enum import ReadConsistency
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
from tcvectordb.rpc.model.collection import RPCCollection
from tcvectordb.rpc.model.database import RPCDatabase
from xinference_client.types import Embedding # type: ignore
from xinference_client.types import Embedding
class MockTcvectordbClass:

View File

@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from volcengine.viking_db import ( # type: ignore
from volcengine.viking_db import (
Collection,
Data,
DistanceType,

View File

@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
"""Test with None input"""
# The method signature expects Union[dict, list, Segment], but implementation handles None
# We'll test the actual behavior by passing an empty dict instead
result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore
result = WorkflowResponseConverter._fetch_files_from_variable_value(None)
assert result == []
def test_fetch_files_from_variable_value_with_empty_dict(self):

View File

@ -235,7 +235,7 @@ class TestIndividualHandlers:
# Type assertion needed due to union type
text_content = result.content[0]
assert hasattr(text_content, "text")
assert text_content.text == "test answer" # type: ignore[attr-defined]
assert text_content.text == "test answer"
def test_handle_call_tool_no_end_user(self):
"""Test call tool handler without end user"""

View File

@ -0,0 +1,181 @@
from __future__ import annotations
import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
class _TestNode(Node):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.EXECUTABLE
@classmethod
def version(cls) -> str:
return "test"
def __init__(
self,
*,
id: str,
config: Mapping[str, object],
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
data = config.get("data", {})
if isinstance(data, Mapping):
execution_type = data.get("execution_type")
if isinstance(execution_type, str):
self.execution_type = NodeExecutionType(execution_type)
self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
self.data: dict[str, object] = {}
def init_node_data(self, data: Mapping[str, object]) -> None:
title = str(data.get("title", self.id))
desc = data.get("description")
error_strategy_value = data.get("error_strategy")
error_strategy: ErrorStrategy | None = None
if isinstance(error_strategy_value, ErrorStrategy):
error_strategy = error_strategy_value
elif isinstance(error_strategy_value, str):
error_strategy = ErrorStrategy(error_strategy_value)
self._base_node_data = BaseNodeData(
title=title,
desc=str(desc) if desc is not None else None,
error_strategy=error_strategy,
)
self.data = dict(data)
def _run(self):
raise NotImplementedError
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._base_node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._base_node_data.retry_config
def _get_title(self) -> str:
return self._base_node_data.title
def _get_description(self) -> str | None:
return self._base_node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._base_node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._base_node_data
@dataclass(slots=True)
class _SimpleNodeFactory:
graph_init_params: GraphInitParams
graph_runtime_state: GraphRuntimeState
def create_node(self, node_config: Mapping[str, object]) -> _TestNode:
node_id = str(node_config["id"])
node = _TestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
node.init_node_data(node_config.get("data", {}))
return node
@pytest.fixture
def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]:
graph_config: dict[str, object] = {"edges": [], "nodes": []}
init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={})
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state)
return factory, graph_config
def test_graph_initialization_runs_default_validators(
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
):
node_factory, graph_config = graph_init_dependencies
graph_config["nodes"] = [
{"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
{"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}},
]
graph_config["edges"] = [
{"source": "start", "target": "answer", "sourceHandle": "success"},
]
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
assert graph.root_node.id == "start"
assert "answer" in graph.nodes
def test_graph_validation_fails_for_unknown_edge_targets(
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:
node_factory, graph_config = graph_init_dependencies
graph_config["nodes"] = [
{"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
]
graph_config["edges"] = [
{"source": "start", "target": "missing", "sourceHandle": "success"},
]
with pytest.raises(GraphValidationError) as exc:
Graph.init(graph_config=graph_config, node_factory=node_factory)
assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues)
def test_graph_promotes_fail_branch_nodes_to_branch_execution_type(
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:
node_factory, graph_config = graph_init_dependencies
graph_config["nodes"] = [
{"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
{
"id": "branch",
"data": {
"type": NodeType.IF_ELSE,
"title": "Branch",
"error_strategy": ErrorStrategy.FAIL_BRANCH,
},
},
]
graph_config["edges"] = [
{"source": "start", "target": "branch", "sourceHandle": "success"},
]
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH

View File

@ -212,7 +212,7 @@ class TestValidateResult:
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
type="select",
description="Status",
required=True,
options=["active", "inactive"],
@ -400,7 +400,7 @@ class TestTransformResult:
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
type="select",
description="Status",
required=True,
options=["active", "inactive"],
@ -414,7 +414,7 @@ class TestTransformResult:
parameters=[
ParameterConfig(
name="status",
type="select", # type: ignore
type="select",
description="Status",
required=True,
options=["active", "inactive"],

View File

@ -248,4 +248,4 @@ def test_constructor_with_extra_key():
# Test that SystemVariable should forbid extra keys
with pytest.raises(ValidationError):
# This should fail because there is an unexpected key.
SystemVariable(invalid_key=1) # type: ignore
SystemVariable(invalid_key=1)

View File

@ -14,36 +14,36 @@ def _create_api_app():
api = ExternalApi(bp)
@api.route("/bad-request")
class Bad(Resource): # type: ignore
def get(self): # type: ignore
class Bad(Resource):
def get(self):
raise BadRequest("invalid input")
@api.route("/unauth")
class Unauth(Resource): # type: ignore
def get(self): # type: ignore
class Unauth(Resource):
def get(self):
raise Unauthorized("auth required")
@api.route("/value-error")
class ValErr(Resource): # type: ignore
def get(self): # type: ignore
class ValErr(Resource):
def get(self):
raise ValueError("boom")
@api.route("/quota")
class Quota(Resource): # type: ignore
def get(self): # type: ignore
class Quota(Resource):
def get(self):
raise AppInvokeQuotaExceededError("quota exceeded")
@api.route("/general")
class Gen(Resource): # type: ignore
def get(self): # type: ignore
class Gen(Resource):
def get(self):
raise RuntimeError("oops")
# Note: We avoid altering default_mediatype to keep normal error paths
# Special 400 message rewrite
@api.route("/json-empty")
class JsonEmpty(Resource): # type: ignore
def get(self): # type: ignore
class JsonEmpty(Resource):
def get(self):
e = BadRequest()
# Force the specific message the handler rewrites
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
@ -51,11 +51,11 @@ def _create_api_app():
# 400 mapping payload path
@api.route("/param-errors")
class ParamErrors(Resource): # type: ignore
def get(self): # type: ignore
class ParamErrors(Resource):
def get(self):
e = BadRequest()
# Coerce a mapping description to trigger param error shaping
e.description = {"field": "is required"} # type: ignore[assignment]
e.description = {"field": "is required"}
raise e
app.register_blueprint(bp, url_prefix="/api")
@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():
orig_exc_info = ext.sys.exc_info
try:
ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment]
ext.sys.exc_info = lambda: (None, None, None)
app = _create_api_app()
client = app.test_client()

View File

@ -67,7 +67,7 @@ def test_current_user_not_accessible_across_threads(login_app: Flask, test_user:
# without preserve_flask_contexts
result["user_accessible"] = current_user.is_authenticated
except Exception as e:
result["error"] = str(e) # type: ignore
result["error"] = str(e)
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread)
@ -110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask,
else:
result["user_accessible"] = False
except Exception as e:
result["error"] = str(e) # type: ignore
result["error"] = str(e)
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread_with_manager)

View File

@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented():
oauth.get_raw_user_info("token")
with pytest.raises(NotImplementedError):
oauth._transform_user_info({}) # type: ignore[name-defined]
oauth._transform_user_info({})

View File

@ -3,8 +3,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from qcloud_cos import CosS3Client # type: ignore
from qcloud_cos.streambody import StreamBody # type: ignore
from qcloud_cos import CosS3Client
from qcloud_cos.streambody import StreamBody
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,

View File

@ -4,8 +4,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from tos import TosClientV2 # type: ignore
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore
from tos import TosClientV2
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,

View File

@ -1,7 +1,7 @@
from unittest.mock import patch
import pytest
from qcloud_cos import CosConfig # type: ignore
from qcloud_cos import CosConfig
from extensions.storage.tencent_cos_storage import TencentCosStorage
from tests.unit_tests.oss.__mock.base import (

View File

@ -1,7 +1,7 @@
from unittest.mock import patch
import pytest
from tos import TosClientV2 # type: ignore
from tos import TosClientV2
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
from tests.unit_tests.oss.__mock.base import (

View File

@ -125,13 +125,13 @@ class TestApiKeyAuthService:
mock_session.commit = Mock()
args_copy = self.mock_args.copy()
original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore
original_key = args_copy["credentials"]["config"]["api_key"]
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
# Verify original key is replaced with encrypted key
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore
assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
assert args_copy["credentials"]["config"]["api_key"] != original_key
# Verify encryption function is called correctly
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
@ -268,7 +268,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_empty_credentials(self):
"""Test API key auth args validation - empty credentials"""
args = self.mock_args.copy()
args["credentials"] = None # type: ignore
args["credentials"] = None
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@ -284,7 +284,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_missing_auth_type(self):
"""Test API key auth args validation - missing auth_type"""
args = self.mock_args.copy()
del args["credentials"]["auth_type"] # type: ignore
del args["credentials"]["auth_type"]
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@ -292,7 +292,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_empty_auth_type(self):
"""Test API key auth args validation - empty auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = "" # type: ignore
args["credentials"]["auth_type"] = ""
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@ -380,7 +380,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
"""Test API key auth args validation - dict credentials with list auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string
args["credentials"]["auth_type"] = ["api_key"]
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
# So this should not raise exception, this test should pass

View File

@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter:
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params(None) # type: ignore
encrypter.encrypt_oauth_params(None)
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
encrypter.encrypt_oauth_params("not_a_dict")
def test_decrypt_oauth_params_basic(self):
"""Test basic OAuth parameters decryption"""
@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter:
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123) # type: ignore
encrypter.decrypt_oauth_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(None) # type: ignore
encrypter.decrypt_oauth_params(None)
assert "encrypted_data must be a string" in str(exc_info.value)
@ -461,14 +461,14 @@ class TestConvenienceFunctions:
"""Test convenience functions with error conditions"""
# Test encryption with invalid input
with pytest.raises(Exception): # noqa: B017
encrypt_system_oauth_params(None) # type: ignore
encrypt_system_oauth_params(None)
# Test decryption with invalid input
with pytest.raises(ValueError):
decrypt_system_oauth_params("")
with pytest.raises(ValueError):
decrypt_system_oauth_params(None) # type: ignore
decrypt_system_oauth_params(None)
class TestErrorHandling:
@ -501,7 +501,7 @@ class TestErrorHandling:
# Test non-string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123) # type: ignore
encrypter.decrypt_oauth_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
# Test invalid format error

View File

@ -0,0 +1,4 @@
// Mock for context-block plugin to avoid circular dependency in Storybook
export const ContextBlockNode = null
export const ContextBlockReplacementBlock = null
export default null

View File

@ -0,0 +1,4 @@
// Mock for history-block plugin to avoid circular dependency in Storybook
export const HistoryBlockNode = null
export const HistoryBlockReplacementBlock = null
export default null

View File

@ -0,0 +1,4 @@
// Mock for query-block plugin to avoid circular dependency in Storybook
export const QueryBlockNode = null
export const QueryBlockReplacementBlock = null
export default null

View File

@ -1,4 +1,9 @@
import type { StorybookConfig } from '@storybook/nextjs'
import path from 'node:path'
import { fileURLToPath } from 'node:url'
const __filename = fileURLToPath(import.meta.url)
const __dirname = path.dirname(__filename)
const config: StorybookConfig = {
stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'],
@ -25,5 +30,17 @@ const config: StorybookConfig = {
docs: {
defaultName: 'Documentation',
},
webpackFinal: async (config) => {
// Add alias to mock problematic modules with circular dependencies
config.resolve = config.resolve || {}
config.resolve.alias = {
...config.resolve.alias,
// Mock the plugin index files to avoid circular dependencies
[path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/context-block/index.tsx')]: path.resolve(__dirname, '__mocks__/context-block.tsx'),
[path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/history-block/index.tsx')]: path.resolve(__dirname, '__mocks__/history-block.tsx'),
[path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/query-block/index.tsx')]: path.resolve(__dirname, '__mocks__/query-block.tsx'),
}
return config
},
}
export default config

View File

@ -160,8 +160,7 @@ describe('Navigation Utilities', () => {
page: 1,
limit: '',
keyword: 'test',
empty: null,
undefined,
filter: '',
})
expect(path).toBe('/datasets/123/documents?page=1&keyword=test')

View File

@ -39,28 +39,38 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa
const isDarkQuery = DARK_MODE_MEDIA_QUERY.test(query)
const matches = isDarkQuery ? systemPrefersDark : false
const handleAddListener = (listener: (event: MediaQueryListEvent) => void) => {
listeners.add(listener)
}
const handleRemoveListener = (listener: (event: MediaQueryListEvent) => void) => {
listeners.delete(listener)
}
const handleAddEventListener = (_event: string, listener: EventListener) => {
if (typeof listener === 'function')
listeners.add(listener as (event: MediaQueryListEvent) => void)
}
const handleRemoveEventListener = (_event: string, listener: EventListener) => {
if (typeof listener === 'function')
listeners.delete(listener as (event: MediaQueryListEvent) => void)
}
const handleDispatchEvent = (event: Event) => {
listeners.forEach(listener => listener(event as MediaQueryListEvent))
return true
}
const mediaQueryList: MediaQueryList = {
matches,
media: query,
onchange: null,
addListener: (listener: MediaQueryListListener) => {
listeners.add(listener)
},
removeListener: (listener: MediaQueryListListener) => {
listeners.delete(listener)
},
addEventListener: (_event, listener: EventListener) => {
if (typeof listener === 'function')
listeners.add(listener as MediaQueryListListener)
},
removeEventListener: (_event, listener: EventListener) => {
if (typeof listener === 'function')
listeners.delete(listener as MediaQueryListListener)
},
dispatchEvent: (event: Event) => {
listeners.forEach(listener => listener(event as MediaQueryListEvent))
return true
},
addListener: handleAddListener,
removeListener: handleRemoveListener,
addEventListener: handleAddEventListener,
removeEventListener: handleRemoveEventListener,
dispatchEvent: handleDispatchEvent,
}
return mediaQueryList
@ -69,6 +79,121 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa
jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia)
}
// Helper function to create timing page component
const createTimingPageComponent = (
timingData: Array<{ phase: string; timestamp: number; styles: { backgroundColor: string; color: string } }>,
) => {
const recordTiming = (phase: string, styles: { backgroundColor: string; color: string }) => {
timingData.push({
phase,
timestamp: performance.now(),
styles,
})
}
const TimingPageComponent = () => {
const [mounted, setMounted] = useState(false)
const { theme } = useTheme()
const isDark = mounted ? theme === 'dark' : false
const currentStyles = {
backgroundColor: isDark ? '#1f2937' : '#ffffff',
color: isDark ? '#ffffff' : '#000000',
}
recordTiming(mounted ? 'CSR' : 'Initial', currentStyles)
useEffect(() => {
setMounted(true)
}, [])
return (
<div
data-testid="timing-page"
style={currentStyles}
>
<div data-testid="timing-status">
Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'}
</div>
</div>
)
}
return TimingPageComponent
}
// Helper function to create CSS test component
const createCSSTestComponent = (
cssStates: Array<{ className: string; timestamp: number }>,
) => {
const recordCSSState = (className: string) => {
cssStates.push({
className,
timestamp: performance.now(),
})
}
const CSSTestComponent = () => {
const [mounted, setMounted] = useState(false)
const { theme } = useTheme()
const isDark = mounted ? theme === 'dark' : false
const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}`
recordCSSState(className)
useEffect(() => {
setMounted(true)
}, [])
return (
<div
data-testid="css-component"
className={className}
>
<div data-testid="css-classes">Classes: {className}</div>
</div>
)
}
return CSSTestComponent
}
// Helper function to create performance test component
const createPerformanceTestComponent = (
performanceMarks: Array<{ event: string; timestamp: number }>,
) => {
const recordPerformanceMark = (event: string) => {
performanceMarks.push({ event, timestamp: performance.now() })
}
const PerformanceTestComponent = () => {
const [mounted, setMounted] = useState(false)
const { theme } = useTheme()
recordPerformanceMark('component-render')
useEffect(() => {
recordPerformanceMark('mount-start')
setMounted(true)
recordPerformanceMark('mount-complete')
}, [])
useEffect(() => {
if (theme)
recordPerformanceMark('theme-available')
}, [theme])
return (
<div data-testid="performance-test">
Mounted: {mounted.toString()} | Theme: {theme || 'loading'}
</div>
)
}
return PerformanceTestComponent
}
// Simulate real page component based on Dify's actual theme usage
const PageComponent = () => {
const [mounted, setMounted] = useState(false)
@ -227,39 +352,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
setupMockEnvironment('dark')
const timingData: Array<{ phase: string; timestamp: number; styles: any }> = []
const TimingPageComponent = () => {
const [mounted, setMounted] = useState(false)
const { theme } = useTheme()
const isDark = mounted ? theme === 'dark' : false
// Record timing and styles for each render phase
const currentStyles = {
backgroundColor: isDark ? '#1f2937' : '#ffffff',
color: isDark ? '#ffffff' : '#000000',
}
timingData.push({
phase: mounted ? 'CSR' : 'Initial',
timestamp: performance.now(),
styles: currentStyles,
})
useEffect(() => {
setMounted(true)
}, [])
return (
<div
data-testid="timing-page"
style={currentStyles}
>
<div data-testid="timing-status">
Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'}
</div>
</div>
)
}
const TimingPageComponent = createTimingPageComponent(timingData)
render(
<TestThemeProvider>
@ -295,33 +388,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
setupMockEnvironment('dark')
const cssStates: Array<{ className: string; timestamp: number }> = []
const CSSTestComponent = () => {
const [mounted, setMounted] = useState(false)
const { theme } = useTheme()
const isDark = mounted ? theme === 'dark' : false
// Simulate Tailwind CSS class application
const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}`
cssStates.push({
className,
timestamp: performance.now(),
})
useEffect(() => {
setMounted(true)
}, [])
return (
<div
data-testid="css-component"
className={className}
>
<div data-testid="css-classes">Classes: {className}</div>
</div>
)
}
const CSSTestComponent = createCSSTestComponent(cssStates)
render(
<TestThemeProvider>
@ -413,34 +480,12 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
test('verifies ThemeProvider position fix reduces initialization delay', async () => {
const performanceMarks: Array<{ event: string; timestamp: number }> = []
const PerformanceTestComponent = () => {
const [mounted, setMounted] = useState(false)
const { theme } = useTheme()
performanceMarks.push({ event: 'component-render', timestamp: performance.now() })
useEffect(() => {
performanceMarks.push({ event: 'mount-start', timestamp: performance.now() })
setMounted(true)
performanceMarks.push({ event: 'mount-complete', timestamp: performance.now() })
}, [])
useEffect(() => {
if (theme)
performanceMarks.push({ event: 'theme-available', timestamp: performance.now() })
}, [theme])
return (
<div data-testid="performance-test">
Mounted: {mounted.toString()} | Theme: {theme || 'loading'}
</div>
)
}
setupMockEnvironment('dark')
expect(window.localStorage.getItem('theme')).toBe('dark')
const PerformanceTestComponent = createPerformanceTestComponent(performanceMarks)
render(
<TestThemeProvider>
<PerformanceTestComponent />

View File

@ -70,14 +70,18 @@ describe('Unified Tags Editing - Pure Logic Tests', () => {
})
describe('Fallback Logic (from layout-main.tsx)', () => {
type Tag = { id: string; name: string }
type AppDetail = { tags: Tag[] }
type FallbackResult = { tags?: Tag[] } | null
// no-op
it('should trigger fallback when tags are missing or empty', () => {
const appDetailWithoutTags = { tags: [] }
const appDetailWithTags = { tags: [{ id: 'tag1' }] }
const appDetailWithUndefinedTags = { tags: undefined as any }
const appDetailWithoutTags: AppDetail = { tags: [] }
const appDetailWithTags: AppDetail = { tags: [{ id: 'tag1', name: 't' }] }
const appDetailWithUndefinedTags: { tags: Tag[] | undefined } = { tags: undefined }
// This simulates the condition in layout-main.tsx
const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0
const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0
const shouldFallback1 = appDetailWithoutTags.tags.length === 0
const shouldFallback2 = appDetailWithTags.tags.length === 0
const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0
expect(shouldFallback1).toBe(true) // Empty array should trigger fallback
@ -86,24 +90,26 @@ describe('Unified Tags Editing - Pure Logic Tests', () => {
})
it('should preserve tags when fallback succeeds', () => {
const originalAppDetail = { tags: [] as any[] }
const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] }
const originalAppDetail: AppDetail = { tags: [] }
const fallbackResult: { tags?: Tag[] } = { tags: [{ id: 'tag1', name: 'fallback-tag' }] }
// This simulates the successful fallback in layout-main.tsx
if (fallbackResult?.tags)
originalAppDetail.tags = fallbackResult.tags
const tags = fallbackResult.tags
if (tags)
originalAppDetail.tags = tags
expect(originalAppDetail.tags).toEqual(fallbackResult.tags)
expect(originalAppDetail.tags.length).toBe(1)
})
it('should continue with empty tags when fallback fails', () => {
const originalAppDetail: { tags: any[] } = { tags: [] }
const fallbackResult: { tags?: any[] } | null = null
const originalAppDetail: AppDetail = { tags: [] }
const fallbackResult = null as FallbackResult
// This simulates fallback failure in layout-main.tsx
if (fallbackResult?.tags)
originalAppDetail.tags = fallbackResult.tags
const tags: Tag[] | undefined = fallbackResult && 'tags' in fallbackResult ? fallbackResult.tags : undefined
if (tags)
originalAppDetail.tags = tags
expect(originalAppDetail.tags).toEqual([])
})

View File

@ -73,7 +73,7 @@ const ConfigPopup: FC<PopupProps> = ({
}
}, [onChooseProvider])
const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => {
const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => {
onConfigUpdated(currentProvider!, payload)
hideConfigModal()
}, [currentProvider, hideConfigModal, onConfigUpdated])

View File

@ -6,7 +6,6 @@ import { useWebAppStore } from '@/context/web-app-context'
import { useRouter, useSearchParams } from 'next/navigation'
import AppUnavailable from '@/app/components/base/app-unavailable'
import { useTranslation } from 'react-i18next'
import { AccessMode } from '@/models/access-control'
import { webAppLoginStatus, webAppLogout } from '@/service/webapp-auth'
import { fetchAccessToken } from '@/service/share'
import Loading from '@/app/components/base/loading'
@ -35,7 +34,6 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
router.replace(url)
}, [getSigninUrl, router, webAppLogout, shareCode])
const needCheckIsLogin = webAppAccessMode !== AccessMode.PUBLIC
const [isLoading, setIsLoading] = useState(true)
useEffect(() => {
if (message) {
@ -58,8 +56,8 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
}
(async () => {
const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(needCheckIsLogin, shareCode!)
// if access mode is public, user login is always true, but the app login(passport) may be expired
const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(shareCode!)
if (userLoggedIn && appLoggedIn) {
redirectOrFinish()
}
@ -87,7 +85,6 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
router,
message,
webAppAccessMode,
needCheckIsLogin,
tokenFromUrl])
if (message) {

View File

@ -0,0 +1,262 @@
import type { Meta, StoryObj } from '@storybook/nextjs'
import { RiAddLine, RiDeleteBinLine, RiEditLine, RiMore2Fill, RiSaveLine, RiShareLine } from '@remixicon/react'
import ActionButton, { ActionButtonState } from '.'
const meta = {
title: 'Base/ActionButton',
component: ActionButton,
parameters: {
layout: 'centered',
docs: {
description: {
component: 'Action button component with multiple sizes and states. Commonly used for toolbar actions and inline operations.',
},
},
},
tags: ['autodocs'],
argTypes: {
size: {
control: 'select',
options: ['xs', 'm', 'l', 'xl'],
description: 'Button size',
},
state: {
control: 'select',
options: [
ActionButtonState.Default,
ActionButtonState.Active,
ActionButtonState.Disabled,
ActionButtonState.Destructive,
ActionButtonState.Hover,
],
description: 'Button state',
},
children: {
control: 'text',
description: 'Button content',
},
disabled: {
control: 'boolean',
description: 'Native disabled state',
},
},
} satisfies Meta<typeof ActionButton>
export default meta
type Story = StoryObj<typeof meta>
// Default state
export const Default: Story = {
args: {
size: 'm',
children: <RiEditLine className="h-4 w-4" />,
},
}
// With text
export const WithText: Story = {
args: {
size: 'm',
children: 'Edit',
},
}
// Icon with text
export const IconWithText: Story = {
args: {
size: 'm',
children: (
<>
<RiAddLine className="mr-1 h-4 w-4" />
Add Item
</>
),
},
}
// Size variations
export const ExtraSmall: Story = {
args: {
size: 'xs',
children: <RiEditLine className="h-3 w-3" />,
},
}
export const Small: Story = {
args: {
size: 'xs',
children: <RiEditLine className="h-3.5 w-3.5" />,
},
}
export const Medium: Story = {
args: {
size: 'm',
children: <RiEditLine className="h-4 w-4" />,
},
}
export const Large: Story = {
args: {
size: 'l',
children: <RiEditLine className="h-5 w-5" />,
},
}
export const ExtraLarge: Story = {
args: {
size: 'xl',
children: <RiEditLine className="h-6 w-6" />,
},
}
// State variations
export const ActiveState: Story = {
args: {
size: 'm',
state: ActionButtonState.Active,
children: <RiEditLine className="h-4 w-4" />,
},
}
export const DisabledState: Story = {
args: {
size: 'm',
state: ActionButtonState.Disabled,
children: <RiEditLine className="h-4 w-4" />,
},
}
export const DestructiveState: Story = {
args: {
size: 'm',
state: ActionButtonState.Destructive,
children: <RiDeleteBinLine className="h-4 w-4" />,
},
}
export const HoverState: Story = {
args: {
size: 'm',
state: ActionButtonState.Hover,
children: <RiEditLine className="h-4 w-4" />,
},
}
// Real-world examples
export const ToolbarActions: Story = {
render: () => (
<div className="flex items-center gap-1 rounded-lg bg-background-section-burn p-2">
<ActionButton size="m">
<RiEditLine className="h-4 w-4" />
</ActionButton>
<ActionButton size="m">
<RiShareLine className="h-4 w-4" />
</ActionButton>
<ActionButton size="m">
<RiSaveLine className="h-4 w-4" />
</ActionButton>
<div className="mx-1 h-4 w-px bg-divider-regular" />
<ActionButton size="m" state={ActionButtonState.Destructive}>
<RiDeleteBinLine className="h-4 w-4" />
</ActionButton>
</div>
),
}
export const InlineActions: Story = {
render: () => (
<div className="flex items-center gap-2">
<span className="text-text-secondary">Item name</span>
<ActionButton size="xs">
<RiEditLine className="h-3.5 w-3.5" />
</ActionButton>
<ActionButton size="xs">
<RiMore2Fill className="h-3.5 w-3.5" />
</ActionButton>
</div>
),
}
export const SizeComparison: Story = {
render: () => (
<div className="flex items-center gap-4">
<div className="flex flex-col items-center gap-2">
<ActionButton size="xs">
<RiEditLine className="h-3 w-3" />
</ActionButton>
<span className="text-xs text-text-tertiary">XS</span>
</div>
<div className="flex flex-col items-center gap-2">
<ActionButton size="xs">
<RiEditLine className="h-3.5 w-3.5" />
</ActionButton>
<span className="text-xs text-text-tertiary">S</span>
</div>
<div className="flex flex-col items-center gap-2">
<ActionButton size="m">
<RiEditLine className="h-4 w-4" />
</ActionButton>
<span className="text-xs text-text-tertiary">M</span>
</div>
<div className="flex flex-col items-center gap-2">
<ActionButton size="l">
<RiEditLine className="h-5 w-5" />
</ActionButton>
<span className="text-xs text-text-tertiary">L</span>
</div>
<div className="flex flex-col items-center gap-2">
<ActionButton size="xl">
<RiEditLine className="h-6 w-6" />
</ActionButton>
<span className="text-xs text-text-tertiary">XL</span>
</div>
</div>
),
}
export const StateComparison: Story = {
render: () => (
<div className="flex items-center gap-4">
<div className="flex flex-col items-center gap-2">
<ActionButton size="m" state={ActionButtonState.Default}>
<RiEditLine className="h-4 w-4" />
</ActionButton>
<span className="text-xs text-text-tertiary">Default</span>
</div>
<div className="flex flex-col items-center gap-2">
<ActionButton size="m" state={ActionButtonState.Active}>
<RiEditLine className="h-4 w-4" />
</ActionButton>
<span className="text-xs text-text-tertiary">Active</span>
</div>
<div className="flex flex-col items-center gap-2">
<ActionButton size="m" state={ActionButtonState.Hover}>
<RiEditLine className="h-4 w-4" />
</ActionButton>
<span className="text-xs text-text-tertiary">Hover</span>
</div>
<div className="flex flex-col items-center gap-2">
<ActionButton size="m" state={ActionButtonState.Disabled}>
<RiEditLine className="h-4 w-4" />
</ActionButton>
<span className="text-xs text-text-tertiary">Disabled</span>
</div>
<div className="flex flex-col items-center gap-2">
<ActionButton size="m" state={ActionButtonState.Destructive}>
<RiDeleteBinLine className="h-4 w-4" />
</ActionButton>
<span className="text-xs text-text-tertiary">Destructive</span>
</div>
</div>
),
}
// Interactive playground
export const Playground: Story = {
args: {
size: 'm',
state: ActionButtonState.Default,
children: <RiEditLine className="h-4 w-4" />,
},
}

View File

@ -0,0 +1,204 @@
import type { Meta, StoryObj } from '@storybook/nextjs'
import { useState } from 'react'
import AutoHeightTextarea from '.'
const meta = {
title: 'Base/AutoHeightTextarea',
component: AutoHeightTextarea,
parameters: {
layout: 'centered',
docs: {
description: {
component: 'Auto-resizing textarea component that expands and contracts based on content, with configurable min/max height constraints.',
},
},
},
tags: ['autodocs'],
argTypes: {
placeholder: {
control: 'text',
description: 'Placeholder text',
},
value: {
control: 'text',
description: 'Textarea value',
},
minHeight: {
control: 'number',
description: 'Minimum height in pixels',
},
maxHeight: {
control: 'number',
description: 'Maximum height in pixels',
},
autoFocus: {
control: 'boolean',
description: 'Auto focus on mount',
},
className: {
control: 'text',
description: 'Additional CSS classes',
},
wrapperClassName: {
control: 'text',
description: 'Wrapper CSS classes',
},
},
} satisfies Meta<typeof AutoHeightTextarea>
export default meta
type Story = StoryObj<typeof meta>
// Interactive demo wrapper
const AutoHeightTextareaDemo = (args: any) => {
const [value, setValue] = useState(args.value || '')
return (
<div style={{ width: '500px' }}>
<AutoHeightTextarea
{...args}
value={value}
onChange={(e) => {
setValue(e.target.value)
console.log('Text changed:', e.target.value)
}}
/>
</div>
)
}
// Default state
export const Default: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Type something...',
value: '',
minHeight: 36,
maxHeight: 96,
className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
},
}
// With initial value
export const WithInitialValue: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Type something...',
value: 'This is a pre-filled textarea with some initial content.',
minHeight: 36,
maxHeight: 96,
className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
},
}
// With multiline content
export const MultilineContent: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Type something...',
value: 'Line 1\nLine 2\nLine 3\nLine 4\nThis textarea automatically expands to fit the content.',
minHeight: 36,
maxHeight: 96,
className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
},
}
// Custom min height
export const CustomMinHeight: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Taller minimum height...',
value: '',
minHeight: 100,
maxHeight: 200,
className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
},
}
// Small max height (scrollable)
export const SmallMaxHeight: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Type multiple lines...',
value: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6\nThis will become scrollable when it exceeds max height.',
minHeight: 36,
maxHeight: 80,
className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
},
}
// Auto focus enabled
export const AutoFocus: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'This textarea auto-focuses on mount',
value: '',
minHeight: 36,
maxHeight: 96,
autoFocus: true,
className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
},
}
// With custom styling
export const CustomStyling: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Custom styled textarea...',
value: '',
minHeight: 50,
maxHeight: 150,
className: 'w-full p-3 bg-gray-50 border-2 border-blue-400 rounded-xl text-lg focus:outline-none focus:bg-white focus:border-blue-600',
wrapperClassName: 'shadow-lg',
},
}
// Long content example
export const LongContent: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Type something...',
value: 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.\n\nUt enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.\n\nDuis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.\n\nExcepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.',
minHeight: 36,
maxHeight: 200,
className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
},
}
// Real-world example - Chat input
export const ChatInput: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Type your message...',
value: '',
minHeight: 40,
maxHeight: 120,
className: 'w-full px-4 py-2 bg-gray-100 border border-gray-300 rounded-2xl text-sm focus:outline-none focus:bg-white focus:ring-2 focus:ring-blue-500',
},
}
// Real-world example - Comment box
export const CommentBox: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Write a comment...',
value: '',
minHeight: 60,
maxHeight: 200,
className: 'w-full p-3 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-indigo-500',
},
}
// Interactive playground
export const Playground: Story = {
render: args => <AutoHeightTextareaDemo {...args} />,
args: {
placeholder: 'Type something...',
value: '',
minHeight: 36,
maxHeight: 96,
autoFocus: false,
className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
wrapperClassName: '',
},
}

View File

@ -31,7 +31,7 @@ const AutoHeightTextarea = (
onKeyDown,
onKeyUp,
}: IProps & {
ref: React.RefObject<unknown>;
ref?: React.RefObject<HTMLTextAreaElement>;
},
) => {
// eslint-disable-next-line react-hooks/rules-of-hooks

View File

@ -0,0 +1,191 @@
import type { Meta, StoryObj } from '@storybook/nextjs'
import { useState } from 'react'
import BlockInput from '.'
const meta = {
title: 'Base/BlockInput',
component: BlockInput,
parameters: {
layout: 'centered',
docs: {
description: {
component: 'Block input component with variable highlighting. Supports {{variable}} syntax with validation and visual highlighting of variable names.',
},
},
},
tags: ['autodocs'],
argTypes: {
value: {
control: 'text',
description: 'Input value (supports {{variable}} syntax)',
},
className: {
control: 'text',
description: 'Wrapper CSS classes',
},
highLightClassName: {
control: 'text',
description: 'CSS class for highlighted variables (default: text-blue-500)',
},
readonly: {
control: 'boolean',
description: 'Read-only mode',
},
},
} satisfies Meta<typeof BlockInput>
export default meta
type Story = StoryObj<typeof meta>
// Interactive demo wrapper
const BlockInputDemo = (args: any) => {
const [value, setValue] = useState(args.value || '')
const [keys, setKeys] = useState<string[]>([])
return (
<div style={{ width: '600px' }}>
<BlockInput
{...args}
value={value}
onConfirm={(newValue, extractedKeys) => {
setValue(newValue)
setKeys(extractedKeys)
console.log('Value confirmed:', newValue)
console.log('Extracted keys:', extractedKeys)
}}
/>
{keys.length > 0 && (
<div className="mt-4 rounded-lg bg-blue-50 p-3">
<div className="mb-2 text-sm font-medium text-gray-700">Detected Variables:</div>
<div className="flex flex-wrap gap-2">
{keys.map(key => (
<span key={key} className="rounded bg-blue-500 px-2 py-1 text-xs text-white">
{key}
</span>
))}
</div>
</div>
)}
</div>
)
}
// Default state
export const Default: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: '',
readonly: false,
},
}
// With single variable
export const SingleVariable: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'Hello {{name}}, welcome to the application!',
readonly: false,
},
}
// With multiple variables
export const MultipleVariables: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'Dear {{user_name}},\n\nYour order {{order_id}} has been shipped to {{address}}.\n\nThank you for shopping with us!',
readonly: false,
},
}
// Complex template
export const ComplexTemplate: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'Hi {{customer_name}},\n\nYour {{product_type}} subscription will renew on {{renewal_date}} for {{amount}}.\n\nYour payment method ending in {{card_last_4}} will be charged.\n\nQuestions? Contact us at {{support_email}}.',
readonly: false,
},
}
// Read-only mode
export const ReadOnlyMode: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'This is a read-only template with {{variable1}} and {{variable2}}.\n\nYou cannot edit this content.',
readonly: true,
},
}
// Empty state
export const EmptyState: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: '',
readonly: false,
},
}
// Long content
export const LongContent: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'Dear {{recipient_name}},\n\nWe are writing to inform you about the upcoming changes to your {{service_name}} account.\n\nEffective {{effective_date}}, your plan will include:\n\n1. Access to {{feature_1}}\n2. {{feature_2}} with unlimited usage\n3. Priority support via {{support_channel}}\n4. Monthly reports sent to {{email_address}}\n\nYour new monthly rate will be {{new_price}}, compared to your current rate of {{old_price}}.\n\nIf you have any questions, please contact our team at {{contact_info}}.\n\nBest regards,\n{{company_name}} Team',
readonly: false,
},
}
// Variables with underscores
export const VariablesWithUnderscores: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'User {{user_id}} from {{user_country}} has {{total_orders}} orders with status {{order_status}}.',
readonly: false,
},
}
// Adjacent variables
export const AdjacentVariables: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'File: {{file_name}}.{{file_extension}} ({{file_size}}{{size_unit}})',
readonly: false,
},
}
// Real-world example - Email template
export const EmailTemplate: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'Subject: Your {{service_name}} account has been created\n\nHi {{first_name}},\n\nWelcome to {{company_name}}! Your account is now active.\n\nUsername: {{username}}\nEmail: {{email}}\n\nGet started at {{app_url}}\n\nThanks,\nThe {{company_name}} Team',
readonly: false,
},
}
// Real-world example - Notification template
export const NotificationTemplate: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: '🔔 {{user_name}} mentioned you in {{channel_name}}\n\n"{{message_preview}}"\n\nReply now: {{message_url}}',
readonly: false,
},
}
// Custom styling
export const CustomStyling: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'This template uses {{custom_variable}} with custom styling.',
readonly: false,
className: 'bg-gray-50 border-2 border-blue-200',
},
}
// Interactive playground
export const Playground: Story = {
render: args => <BlockInputDemo {...args} />,
args: {
value: 'Try editing this text and adding variables like {{example}}',
readonly: false,
className: '',
highLightClassName: '',
},
}

Some files were not shown because too many files have changed in this diff Show More