mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/memory-orchestration-fed
commit c8188274a2
@@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
         default="postgresql",
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         db_extras = (

@@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
         default=os.cpu_count() or 1,
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         # Parse DB_EXTRAS for 'options'
@@ -24,7 +24,7 @@ except ImportError:
        )
    else:
        warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
-        magic = None  # type: ignore
+        magic = None  # type: ignore[assignment]

 from pydantic import BaseModel

@@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             stream=streaming,
         )
-        # FIXME: Type hinting issue here, ignore it for now, will fix it later
-        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)  # type: ignore
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

     def _generate_worker(
         self,
@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
                 response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk
@@ -98,7 +98,7 @@ class RateLimit:
         else:
             return RateLimitGenerator(
                 rate_limit=self,
-                generator=generator,  # ty: ignore [invalid-argument-type]
+                generator=generator,
                 request_id=request_id,
             )

@@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError | ValueError):
-            err = e  # ty: ignore [invalid-assignment]
+            err = e
         else:
             description = getattr(e, "description", None)
             err = Exception(description if description is not None else str(e))
@@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
         if "/" not in key:
             key = str(ModelProviderID(key))

-        return self.configurations.get(key, default)  # type: ignore
+        return self.configurations.get(key, default)


 class ProviderModelBundle(BaseModel):
@@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader
     else:
         # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
         # FIXME: mypy does not support the type of spec.loader
-        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore
+        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore[assignment]
         if not spec or not spec.loader:
             raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
         if use_lazy_loader:
@@ -49,62 +49,80 @@ class IndexingRunner:
         self.storage = storage
         self.model_manager = ModelManager()

+    def _handle_indexing_error(self, document_id: str, error: Exception) -> None:
+        """Handle indexing errors by updating document status."""
+        logger.exception("consume document failed")
+        document = db.session.get(DatasetDocument, document_id)
+        if document:
+            document.indexing_status = "error"
+            error_message = getattr(error, "description", str(error))
+            document.error = str(error_message)
+            document.stopped_at = naive_utc_now()
+            db.session.commit()
+
     def run(self, dataset_documents: list[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
+            document_id = dataset_document.id
             try:
+                # Re-query the document to ensure it's bound to the current session
+                requeried_document = db.session.get(DatasetDocument, document_id)
+                if not requeried_document:
+                    logger.warning("Document not found, skipping document id: %s", document_id)
+                    continue
+
                 # get dataset
-                dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

                 if not dataset:
                     raise ValueError("no dataset found")
                 # get the process rule
                 stmt = select(DatasetProcessRule).where(
-                    DatasetProcessRule.id == dataset_document.dataset_process_rule_id
+                    DatasetProcessRule.id == requeried_document.dataset_process_rule_id
                 )
                 processing_rule = db.session.scalar(stmt)
                 if not processing_rule:
                     raise ValueError("no process rule found")
-                index_type = dataset_document.doc_form
+                index_type = requeried_document.doc_form
                 index_processor = IndexProcessorFactory(index_type).init_index_processor()
                 # extract
-                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+                text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())

                 # transform
                 documents = self._transform(
-                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                    index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
                 )
                 # save segment
-                self._load_segments(dataset, dataset_document, documents)
+                self._load_segments(dataset, requeried_document, documents)

                 # load
                 self._load(
                     index_processor=index_processor,
                     dataset=dataset,
-                    dataset_document=dataset_document,
+                    dataset_document=requeried_document,
                     documents=documents,
                 )
             except DocumentIsPausedError:
-                raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+                raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
             except ProviderTokenNotInitError as e:
-                dataset_document.indexing_status = "error"
-                dataset_document.error = str(e.description)
-                dataset_document.stopped_at = naive_utc_now()
-                db.session.commit()
+                self._handle_indexing_error(document_id, e)
             except ObjectDeletedError:
-                logger.warning("Document deleted, document id: %s", dataset_document.id)
+                logger.warning("Document deleted, document id: %s", document_id)
             except Exception as e:
-                logger.exception("consume document failed")
-                dataset_document.indexing_status = "error"
-                dataset_document.error = str(e)
-                dataset_document.stopped_at = naive_utc_now()
-                db.session.commit()
+                self._handle_indexing_error(document_id, e)

     def run_in_splitting_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is splitting."""
+        document_id = dataset_document.id
         try:
+            # Re-query the document to ensure it's bound to the current session
+            requeried_document = db.session.get(DatasetDocument, document_id)
+            if not requeried_document:
+                logger.warning("Document not found: %s", document_id)
+                return
+
             # get dataset
-            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

             if not dataset:
                 raise ValueError("no dataset found")

@@ -112,57 +130,60 @@ class IndexingRunner:
             # get exist document_segment list and delete
             document_segments = (
                 db.session.query(DocumentSegment)
-                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
                 .all()
             )

             for document_segment in document_segments:
                 db.session.delete(document_segment)
-                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     # delete child chunks
                     db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
             db.session.commit()
             # get the process rule
-            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
             processing_rule = db.session.scalar(stmt)
             if not processing_rule:
                 raise ValueError("no process rule found")

-            index_type = dataset_document.doc_form
+            index_type = requeried_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             # extract
-            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+            text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())

             # transform
             documents = self._transform(
-                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
             )
             # save segment
-            self._load_segments(dataset, dataset_document, documents)
+            self._load_segments(dataset, requeried_document, documents)

             # load
             self._load(
-                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+                index_processor=index_processor,
+                dataset=dataset,
+                dataset_document=requeried_document,
+                documents=documents,
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+            raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
         except ProviderTokenNotInitError as e:
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e.description)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
         except Exception as e:
-            logger.exception("consume document failed")
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)

     def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
+        document_id = dataset_document.id
         try:
+            # Re-query the document to ensure it's bound to the current session
+            requeried_document = db.session.get(DatasetDocument, document_id)
+            if not requeried_document:
+                logger.warning("Document not found: %s", document_id)
+                return
+
             # get dataset
-            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()

             if not dataset:
                 raise ValueError("no dataset found")

@@ -170,7 +191,7 @@ class IndexingRunner:
             # get exist document_segment list and delete
             document_segments = (
                 db.session.query(DocumentSegment)
-                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
                 .all()
             )

@@ -188,7 +209,7 @@ class IndexingRunner:
                         "dataset_id": document_segment.dataset_id,
                     },
                 )
-                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     child_chunks = document_segment.get_child_chunks()
                     if child_chunks:
                         child_documents = []

@@ -206,24 +227,20 @@ class IndexingRunner:
                     document.children = child_documents
                 documents.append(document)
             # build index
-            index_type = dataset_document.doc_form
+            index_type = requeried_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             self._load(
-                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+                index_processor=index_processor,
+                dataset=dataset,
+                dataset_document=requeried_document,
+                documents=documents,
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
+            raise DocumentIsPausedError(f"Document paused, document id: {document_id}")
         except ProviderTokenNotInitError as e:
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e.description)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)
         except Exception as e:
-            logger.exception("consume document failed")
-            dataset_document.indexing_status = "error"
-            dataset_document.error = str(e)
-            dataset_document.stopped_at = naive_utc_now()
-            db.session.commit()
+            self._handle_indexing_error(document_id, e)

     def indexing_estimate(
         self,
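The re-query step added throughout this file is the heart of the change: an ORM object created in an earlier session can arrive here detached, and touching its lazy attributes then raises, while a fresh fetch by primary key binds the instance to the live session and doubles as an existence check. A minimal, self-contained SQLAlchemy 2.0 sketch of the pattern (the Doc model and setup are hypothetical stand-ins, not Dify code):

    # session.get() returns an instance bound to *this* session, or None if the
    # row was deleted in the meantime - so attribute access cannot raise
    # DetachedInstanceError, and a missing document can be skipped cleanly.
    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Doc(Base):  # hypothetical stand-in for DatasetDocument
        __tablename__ = "docs"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        status: Mapped[str] = mapped_column(String, default="waiting")

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Doc(id="d1"))
        session.commit()

    with Session(engine) as session:  # a later, separate unit of work
        doc = session.get(Doc, "d1")  # bound to this session, or None
        if doc is None:
            print("document gone, skipping")
        else:
            doc.status = "indexing"
            session.commit()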
@@ -2,7 +2,7 @@ import logging
 import os
 from datetime import datetime, timedelta

-from langfuse import Langfuse  # type: ignore
+from langfuse import Langfuse
 from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
@@ -180,7 +180,7 @@ class BasePluginClient:
         Make a request to the plugin daemon inner API and return the response as a model.
         """
         response = self._request(method, path, headers, data, params, files)
-        return type_(**response.json())  # type: ignore
+        return type_(**response.json())  # type: ignore[return-value]

     def _request_with_plugin_daemon_response(
         self,
@@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id
@@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id
@@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:

         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,

@@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:

         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,
@@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         returns the tool that the provider can provide
         """
-        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)  # type: ignore
+        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)

     @property
     def need_credentials(self) -> bool:
@@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
             content_text=tool_parameters.get("text"),  # type: ignore
             user=user_id,
             tenant_id=self.runtime.tenant_id,
-            voice=voice,  # type: ignore
+            voice=voice,
         )
         buffer = io.BytesIO()
         for chunk in tts:
@@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):

         yield self.create_text_message(f"{timestamp}")

+    # TODO: this method's type is messy
     @staticmethod
     def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
         try:
@@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
             datetime_with_tz = input_timezone.localize(local_time)
             # timezone convert
             converted_datetime = datetime_with_tz.astimezone(output_timezone)
-            return converted_datetime.strftime(format=time_format)  # type: ignore
+            return converted_datetime.strftime(time_format)
         except Exception as e:
             raise ToolInvokeError(str(e))
@@ -105,7 +105,7 @@ class MCPToolProviderController(ToolProviderController):
         """
         pass

-    def get_tool(self, tool_name: str) -> MCPTool:  # type: ignore
+    def get_tool(self, tool_name: str) -> MCPTool:
         """
         return tool with given name
         """

@@ -128,7 +128,7 @@ class MCPToolProviderController(ToolProviderController):
             sse_read_timeout=self.sse_read_timeout,
         )

-    def get_tools(self) -> list[MCPTool]:  # type: ignore
+    def get_tools(self) -> list[MCPTool]:
         """
         get all tools
         """
@@ -26,7 +26,7 @@ class ToolLabelManager:
         labels = cls.filter_tool_labels(labels)

         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         else:
             raise ValueError("Unsupported tool type")


@@ -51,7 +51,7 @@ class ToolLabelManager:
         Get tool labels
         """
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         elif isinstance(controller, BuiltinToolProviderController):
             return controller.tool_labels
         else:

@@ -85,7 +85,7 @@ class ToolLabelManager:
         provider_ids = []
         for controller in tool_providers:
             assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
-            provider_ids.append(controller.provider_id)  # ty: ignore [unresolved-attribute]
+            provider_ids.append(controller.provider_id)

         labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

@@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                 DatasetDocument.enabled == True,
                 DatasetDocument.archived == False,
             )
-            document = db.session.scalar(dataset_document_stmt)  # type: ignore
+            document = db.session.scalar(dataset_document_stmt)
             if dataset and document:
                 source = RetrievalSourceMetadata(
                     dataset_id=dataset.id,
                     dataset_name=dataset.name,
-                    document_id=document.id,  # type: ignore
-                    document_name=document.name,  # type: ignore
-                    data_source_type=document.data_source_type,  # type: ignore
+                    document_id=document.id,
+                    document_name=document.name,
+                    data_source_type=document.data_source_type,
                     segment_id=segment.id,
                     retriever_from=self.retriever_from,
                     score=record.score or 0.0,
-                    doc_metadata=document.doc_metadata,  # type: ignore
+                    doc_metadata=document.doc_metadata,
                 )

                 if self.retriever_from == "dev":
@@ -6,8 +6,8 @@ from typing import Any, cast
 from urllib.parse import unquote

 import chardet
-import cloudscraper  # type: ignore
-from readabilipy import simple_json_from_html_string  # type: ignore
+import cloudscraper
+from readabilipy import simple_json_from_html_string

 from core.helper import ssrf_proxy
 from core.rag.extractor import extract_processor

@@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
         response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
     elif response.status_code == 403:
         scraper = cloudscraper.create_scraper()
-        scraper.perform_request = ssrf_proxy.make_request  # type: ignore
-        response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))  # type: ignore
+        scraper.perform_request = ssrf_proxy.make_request
+        response = scraper.get(url, headers=headers, timeout=(120, 300))

     if response.status_code != 200:
         return f"URL returned status code {response.status_code}."
@@ -3,7 +3,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any

-import yaml  # type: ignore
+import yaml
 from yaml import YAMLError

 logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ class WorkflowToolProviderController(ToolProviderController):
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)

         def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
-            return next(filter(lambda x: x.variable == variable_name, variables), None)  # type: ignore
+            return next(filter(lambda x: x.variable == variable_name, variables), None)

         user = db_provider.user

@@ -4,7 +4,7 @@ from .types import SegmentType

 class SegmentGroup(Segment):
     value_type: SegmentType = SegmentType.GROUP
-    value: list[Segment] = None  # type: ignore
+    value: list[Segment]

     @property
     def text(self):

@@ -19,7 +19,7 @@ class Segment(BaseModel):
     model_config = ConfigDict(frozen=True)

     value_type: SegmentType
-    value: Any = None
+    value: Any

     @field_validator("value_type")
     @classmethod

@@ -74,12 +74,12 @@ class NoneSegment(Segment):

 class StringSegment(Segment):
     value_type: SegmentType = SegmentType.STRING
-    value: str = None  # type: ignore
+    value: str


 class FloatSegment(Segment):
     value_type: SegmentType = SegmentType.FLOAT
-    value: float = None  # type: ignore
+    value: float
     # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
     # The following tests cannot pass.
     #

@@ -98,12 +98,12 @@ class FloatSegment(Segment):

 class IntegerSegment(Segment):
     value_type: SegmentType = SegmentType.INTEGER
-    value: int = None  # type: ignore
+    value: int


 class ObjectSegment(Segment):
     value_type: SegmentType = SegmentType.OBJECT
-    value: Mapping[str, Any] = None  # type: ignore
+    value: Mapping[str, Any]

     @property
     def text(self) -> str:

@@ -136,7 +136,7 @@ class ArraySegment(Segment):

 class FileSegment(Segment):
     value_type: SegmentType = SegmentType.FILE
-    value: File = None  # type: ignore
+    value: File

     @property
     def markdown(self) -> str:

@@ -153,17 +153,17 @@ class FileSegment(Segment):

 class BooleanSegment(Segment):
     value_type: SegmentType = SegmentType.BOOLEAN
-    value: bool = None  # type: ignore
+    value: bool


 class ArrayAnySegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_ANY
-    value: Sequence[Any] = None  # type: ignore
+    value: Sequence[Any]


 class ArrayStringSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_STRING
-    value: Sequence[str] = None  # type: ignore
+    value: Sequence[str]

     @property
     def text(self) -> str:

@@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):

 class ArrayNumberSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_NUMBER
-    value: Sequence[float | int] = None  # type: ignore
+    value: Sequence[float | int]


 class ArrayObjectSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_OBJECT
-    value: Sequence[Mapping[str, Any]] = None  # type: ignore
+    value: Sequence[Mapping[str, Any]]


 class ArrayFileSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_FILE
-    value: Sequence[File] = None  # type: ignore
+    value: Sequence[File]

     @property
     def markdown(self) -> str:

@@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):

 class ArrayBooleanSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
-    value: Sequence[bool] = None  # type: ignore
+    value: Sequence[bool]


 def get_segment_discriminator(v: Any) -> SegmentType | None:
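Dropping the `= None  # type: ignore` defaults on these segment classes is a behavioral change, not just type hygiene: with no default, pydantic treats `value` as required and rejects construction instead of silently storing a None that the annotation denied. A standalone sketch with a hypothetical demo model (not Dify code):

    from pydantic import BaseModel, ValidationError

    class StringSegmentDemo(BaseModel):  # hypothetical stand-in for StringSegment
        value: str  # no default, so the field is required

    try:
        StringSegmentDemo()
    except ValidationError as e:
        print(e.errors()[0]["type"])  # "missing"

    print(StringSegmentDemo(value="hello").value)  # "hello"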
@@ -1,3 +1,5 @@
+from ..runtime.graph_runtime_state import GraphRuntimeState
+from ..runtime.variable_pool import VariablePool
 from .agent import AgentNodeStrategyInit
 from .graph_init_params import GraphInitParams
 from .workflow_execution import WorkflowExecution

@@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
 __all__ = [
     "AgentNodeStrategyInit",
     "GraphInitParams",
+    "GraphRuntimeState",
+    "VariablePool",
     "WorkflowExecution",
     "WorkflowNodeExecution",
 ]
@@ -3,11 +3,12 @@ from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import Protocol, cast, final

-from core.workflow.enums import NodeExecutionType, NodeState, NodeType
+from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
 from core.workflow.nodes.base.node import Node
 from libs.typing import is_str, is_str_dict

 from .edge import Edge
+from .validation import get_graph_validator

 logger = logging.getLogger(__name__)

@@ -201,6 +202,17 @@ class Graph:

         return GraphBuilder(graph_cls=cls)

+    @classmethod
+    def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
+        """
+        Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
+
+        :param nodes: mapping of node ID to node instance
+        """
+        for node in nodes.values():
+            if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
+                node.execution_type = NodeExecutionType.BRANCH
+
     @classmethod
     def _mark_inactive_root_branches(
         cls,

@@ -307,6 +319,9 @@ class Graph:
         # Create node instances
         nodes = cls._create_node_instances(node_configs_map, node_factory)

+        # Promote fail-branch nodes to branch execution type at graph level
+        cls._promote_fail_branch_nodes(nodes)
+
         # Get root node instance
         root_node = nodes[root_node_id]

@@ -314,7 +329,7 @@ class Graph:
         cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)

         # Create and return the graph
-        return cls(
+        graph = cls(
             nodes=nodes,
             edges=edges,
             in_edges=in_edges,

@@ -322,6 +337,11 @@ class Graph:
             root_node=root_node,
         )

+        # Validate the graph structure using built-in validators
+        get_graph_validator().validate(graph)
+
+        return graph
+
     @property
     def node_ids(self) -> list[str]:
         """
@@ -0,0 +1,125 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Protocol
+
+from core.workflow.enums import NodeExecutionType, NodeType
+
+if TYPE_CHECKING:
+    from .graph import Graph
+
+
+@dataclass(frozen=True, slots=True)
+class GraphValidationIssue:
+    """Immutable value object describing a single validation issue."""
+
+    code: str
+    message: str
+    node_id: str | None = None
+
+
+class GraphValidationError(ValueError):
+    """Raised when graph validation fails."""
+
+    def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
+        if not issues:
+            raise ValueError("GraphValidationError requires at least one issue.")
+        self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
+        message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
+        super().__init__(message)
+
+
+class GraphValidationRule(Protocol):
+    """Protocol that individual validation rules must satisfy."""
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        """Validate the provided graph and return any discovered issues."""
+        ...
+
+
+@dataclass(frozen=True, slots=True)
+class _EdgeEndpointValidator:
+    """Ensures all edges reference existing nodes."""
+
+    missing_node_code: str = "MISSING_NODE"
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        issues: list[GraphValidationIssue] = []
+        for edge in graph.edges.values():
+            if edge.tail not in graph.nodes:
+                issues.append(
+                    GraphValidationIssue(
+                        code=self.missing_node_code,
+                        message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
+                        node_id=edge.tail,
+                    )
+                )
+            if edge.head not in graph.nodes:
+                issues.append(
+                    GraphValidationIssue(
+                        code=self.missing_node_code,
+                        message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
+                        node_id=edge.head,
+                    )
+                )
+        return issues
+
+
+@dataclass(frozen=True, slots=True)
+class _RootNodeValidator:
+    """Validates root node invariants."""
+
+    invalid_root_code: str = "INVALID_ROOT"
+    container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
+
+    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
+        root_node = graph.root_node
+        issues: list[GraphValidationIssue] = []
+        if root_node.id not in graph.nodes:
+            issues.append(
+                GraphValidationIssue(
+                    code=self.invalid_root_code,
+                    message=f"Root node '{root_node.id}' is missing from the node registry.",
+                    node_id=root_node.id,
+                )
+            )
+            return issues
+
+        node_type = getattr(root_node, "node_type", None)
+        if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
+            issues.append(
+                GraphValidationIssue(
+                    code=self.invalid_root_code,
+                    message=f"Root node '{root_node.id}' must declare execution type 'root'.",
+                    node_id=root_node.id,
+                )
+            )
+        return issues
+
+
+@dataclass(frozen=True, slots=True)
+class GraphValidator:
+    """Coordinates execution of graph validation rules."""
+
+    rules: tuple[GraphValidationRule, ...]
+
+    def validate(self, graph: Graph) -> None:
+        """Validate the graph against all configured rules."""
+        issues: list[GraphValidationIssue] = []
+        for rule in self.rules:
+            issues.extend(rule.validate(graph))
+
+        if issues:
+            raise GraphValidationError(issues)
+
+
+_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
+    _EdgeEndpointValidator(),
+    _RootNodeValidator(),
+)
+
+
+def get_graph_validator() -> GraphValidator:
+    """Construct the validator composed of default rules."""
+    return GraphValidator(_DEFAULT_RULES)
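For callers, the practical effect of this new module is that malformed graphs now fail fast at construction time, since the Graph build path above calls get_graph_validator().validate(graph) before returning. A hedged usage sketch, assuming a graph object built elsewhere; only names defined in the module above are used:

    # Collect all validation issues at once instead of failing on the first.
    try:
        get_graph_validator().validate(graph)  # also runs automatically during graph construction
    except GraphValidationError as exc:
        for issue in exc.issues:  # tuple of GraphValidationIssue records
            print(f"{issue.code}: {issue.message} (node={issue.node_id})")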
@@ -1,5 +1,6 @@
 import json
 from abc import ABC
+from builtins import type as type_
 from collections.abc import Sequence
 from enum import StrEnum
 from typing import Any, Union

@@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
             raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")

     @staticmethod
-    def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
+    def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
         """Unified array type validation"""
-        # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
-        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)  # type: ignore
+        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

     @staticmethod
     def _convert_number(value: str) -> float:
@@ -10,10 +10,10 @@ from typing import Any
 import chardet
 import docx
 import pandas as pd
-import pypandoc  # type: ignore
-import pypdfium2  # type: ignore
-import webvtt  # type: ignore
-import yaml  # type: ignore
+import pypandoc
+import pypdfium2
+import webvtt
+import yaml
 from docx.document import Document
 from docx.oxml.table import CT_Tbl
 from docx.oxml.text.paragraph import CT_P
@@ -141,7 +141,7 @@ class KnowledgeRetrievalNode(Node):
     def version(cls):
         return "1"

-    def _run(self) -> NodeRunResult:  # type: ignore
+    def _run(self) -> NodeRunResult:
         # extract variables
         variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
         if not isinstance(variable, StringSegment):

@@ -443,7 +443,7 @@ class KnowledgeRetrievalNode(Node):
                 metadata_condition = MetadataCondition(
                     logical_operator=node_data.metadata_filtering_conditions.logical_operator
                     if node_data.metadata_filtering_conditions
-                    else "or",  # type: ignore
+                    else "or",
                     conditions=conditions,
                 )
             elif node_data.metadata_filtering_mode == "manual":

@@ -457,10 +457,10 @@ class KnowledgeRetrievalNode(Node):
                         expected_value = self.graph_runtime_state.variable_pool.convert_template(
                             expected_value
                         ).value[0]
-                        if expected_value.value_type in {"number", "integer", "float"}:  # type: ignore
-                            expected_value = expected_value.value  # type: ignore
-                        elif expected_value.value_type == "string":  # type: ignore
-                            expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()  # type: ignore
+                        if expected_value.value_type in {"number", "integer", "float"}:
+                            expected_value = expected_value.value
+                        elif expected_value.value_type == "string":
+                            expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
                         else:
                             raise ValueError("Invalid expected metadata value type")
                         conditions.append(

@@ -487,7 +487,7 @@ class KnowledgeRetrievalNode(Node):
             if (
                 node_data.metadata_filtering_conditions
                 and node_data.metadata_filtering_conditions.logical_operator == "and"
-            ):  # type: ignore
+            ):
                 document_query = document_query.where(and_(*filters))
             else:
                 document_query = document_query.where(or_(*filters))
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final

 from typing_extensions import override

-from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
+from core.workflow.enums import NodeType
 from core.workflow.graph import NodeFactory
 from core.workflow.nodes.base.node import Node
 from libs.typing import is_str, is_str_dict

@@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
             raise ValueError(f"Node {node_id} missing data information")
         node_instance.init_node_data(node_data)

-        # If node has fail branch, change execution type to branch
-        if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
-            node_instance.execution_type = NodeExecutionType.BRANCH
-
         return node_instance
@@ -747,7 +747,7 @@ class ParameterExtractorNode(Node):
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
-                text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
+                text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
             )
             user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
             return [system_prompt_messages, user_prompt_message]
@@ -135,7 +135,7 @@ Here are the chat histories between human and assistant, inside <histories></histories>
 ### Instructions:
 Some extra information are provided below, you should always follow the instructions as possible as you can.
 <instructions>
-{{instructions}}
+{instructions}
 </instructions>
 """

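The template change pairs with the .format(...) change in the previous hunk. Plain str.format rewrites the doubled braces {{instructions}} to the literal text {instructions}, so the old .replace("{{instructions}}", ...) chained after .format(histories=...) could never find its target and the instructions were silently dropped. A standalone demonstration of the mechanism (not Dify code):

    # Old approach: format() consumes the doubled braces, defeating the replace().
    template_old = "<histories>{histories}</histories> <instructions>{{instructions}}</instructions>"
    step1 = template_old.format(histories="hi")
    print(step1)                                       # ...<instructions>{instructions}</instructions>
    print(step1.replace("{{instructions}}", "be brief"))  # unchanged: no match to replace

    # Fixed approach: single braces, both values filled in one format() call.
    template_new = "<histories>{histories}</histories> <instructions>{instructions}</instructions>"
    print(template_new.format(histories="hi", instructions="be brief"))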
@@ -260,7 +260,7 @@ class VariablePool(BaseModel):
             # This ensures that we can keep the id of the system variables intact.
             if self._has(selector):
                 continue
-            self.add(selector, value)  # type: ignore
+            self.add(selector, value)

     @classmethod
     def empty(cls) -> "VariablePool":
@@ -1,7 +1,12 @@
 from configs import dify_config
-from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN
+from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN, HEADER_NAME_PASSPORT
 from dify_app import DifyApp

+BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
+SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
+AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
+FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
+

 def init_app(app: DifyApp):
     # register blueprint routers

@@ -17,7 +22,7 @@ def init_app(app: DifyApp):

     CORS(
         service_api_bp,
-        allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
+        allow_headers=list(SERVICE_API_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
     )
     app.register_blueprint(service_api_bp)

@@ -26,7 +31,7 @@ def init_app(app: DifyApp):
         web_bp,
         resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
         supports_credentials=True,
-        allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN],
+        allow_headers=list(AUTHENTICATED_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
         expose_headers=["X-Version", "X-Env"],
     )

@@ -36,7 +41,7 @@ def init_app(app: DifyApp):
         console_app_bp,
         resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
         supports_credentials=True,
-        allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
+        allow_headers=list(AUTHENTICATED_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
         expose_headers=["X-Version", "X-Env"],
     )

@@ -44,7 +49,7 @@ def init_app(app: DifyApp):

     CORS(
         files_bp,
-        allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
+        allow_headers=list(FILES_HEADERS),
         methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
     )
     app.register_blueprint(files_bp)
@@ -7,7 +7,7 @@ def is_enabled() -> bool:


 def init_app(app: DifyApp):
-    from flask_compress import Compress  # type: ignore
+    from flask_compress import Compress

     compress = Compress()
     compress.init_app(app)
@@ -1,6 +1,6 @@
 import json

-import flask_login  # type: ignore
+import flask_login
 from flask import Response, request
 from flask_login import user_loaded_from_request, user_logged_in
 from werkzeug.exceptions import NotFound, Unauthorized
@@ -2,7 +2,7 @@ from dify_app import DifyApp


 def init_app(app: DifyApp):
-    import flask_migrate  # type: ignore
+    import flask_migrate

     from extensions.ext_database import db

@@ -103,7 +103,7 @@ def init_app(app: DifyApp):
     def shutdown_tracer():
         provider = trace.get_tracer_provider()
         if hasattr(provider, "force_flush"):
-            provider.force_flush()  # ty: ignore [call-non-callable]
+            provider.force_flush()

     class ExceptionLoggingHandler(logging.Handler):
         """Custom logging handler that creates spans for logging.exception() calls"""
@@ -6,4 +6,4 @@ def init_app(app: DifyApp):
     if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
         from werkzeug.middleware.proxy_fix import ProxyFix

-        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore
+        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore[method-assign]
@@ -5,7 +5,7 @@ from dify_app import DifyApp
 def init_app(app: DifyApp):
     if dify_config.SENTRY_DSN:
         import sentry_sdk
-        from langfuse import parse_error  # type: ignore
+        from langfuse import parse_error
         from sentry_sdk.integrations.celery import CeleryIntegration
         from sentry_sdk.integrations.flask import FlaskIntegration
         from werkzeug.exceptions import HTTPException
@@ -1,7 +1,7 @@
 import posixpath
 from collections.abc import Generator

-import oss2 as aliyun_s3  # type: ignore
+import oss2 as aliyun_s3

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage
@@ -2,9 +2,9 @@ import base64
 import hashlib
 from collections.abc import Generator

-from baidubce.auth.bce_credentials import BceCredentials  # type: ignore
-from baidubce.bce_client_configuration import BceClientConfiguration  # type: ignore
-from baidubce.services.bos.bos_client import BosClient  # type: ignore
+from baidubce.auth.bce_credentials import BceCredentials
+from baidubce.bce_client_configuration import BceClientConfiguration
+from baidubce.services.bos.bos_client import BosClient

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage
@@ -11,7 +11,7 @@ from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path

-import clickzetta  # type: ignore[import]
+import clickzetta
 from pydantic import BaseModel, model_validator

 from extensions.storage.base_storage import BaseStorage
@@ -34,7 +34,7 @@ class VolumePermissionManager:
         # Support two initialization methods: connection object or configuration dictionary
         if isinstance(connection_or_config, dict):
             # Create connection from configuration dictionary
-            import clickzetta  # type: ignore[import-untyped]
+            import clickzetta

             config = connection_or_config
             self._connection = clickzetta.connect(
@@ -3,7 +3,7 @@ import io
 import json
 from collections.abc import Generator

-from google.cloud import storage as google_cloud_storage  # type: ignore
+from google.cloud import storage as google_cloud_storage

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage
@@ -1,6 +1,6 @@
 from collections.abc import Generator

-from obs import ObsClient  # type: ignore
+from obs import ObsClient

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage
@@ -1,7 +1,7 @@
 from collections.abc import Generator

-import boto3  # type: ignore
-from botocore.exceptions import ClientError  # type: ignore
+import boto3
+from botocore.exceptions import ClientError

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage
@@ -1,6 +1,6 @@
 from collections.abc import Generator

-from qcloud_cos import CosConfig, CosS3Client  # type: ignore
+from qcloud_cos import CosConfig, CosS3Client

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage
@@ -1,6 +1,6 @@
 from collections.abc import Generator

-import tos  # type: ignore
+import tos

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage
@@ -146,6 +146,6 @@ class ExternalApi(Api):
         kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False

         # manual separate call on construction and init_app to ensure configs in kwargs effective
-        super().__init__(app=None, *args, **kwargs)  # type: ignore
+        super().__init__(app=None, *args, **kwargs)
         self.init_app(app, **kwargs)
         register_external_error_handlers(self)
@@ -23,7 +23,7 @@ from hashlib import sha1

 import Crypto.Hash.SHA1
 import Crypto.Util.number
-import gmpy2  # type: ignore
+import gmpy2
 from Crypto import Random
 from Crypto.Signature.pss import MGF1
 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes

@@ -136,7 +136,7 @@ class PKCS1OAepCipher:
         # Step 3a (OS2IP)
         em_int = bytes_to_long(em)
         # Step 3b (RSAEP)
-        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)  # ty: ignore [unresolved-attribute]
+        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
         # Step 3c (I2OSP)
         c = long_to_bytes(m_int, k)
         return c

@@ -169,7 +169,7 @@ class PKCS1OAepCipher:
         ct_int = bytes_to_long(ciphertext)
         # Step 2b (RSADP)
         # m_int = self._key._decrypt(ct_int)
-        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)  # ty: ignore [unresolved-attribute]
+        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
         # Complete step 2c (I2OSP)
         em = long_to_bytes(m_int, k)
         # Step 3a

@@ -191,12 +191,12 @@ class PKCS1OAepCipher:
         # Step 3g
         one_pos = hLen + db[hLen:].find(b"\x01")
         lHash1 = db[:hLen]
-        invalid = bord(y) | int(one_pos < hLen)  # type: ignore
+        invalid = bord(y) | int(one_pos < hLen)  # type: ignore[arg-type]
         hash_compare = strxor(lHash1, lHash)
         for x in hash_compare:
-            invalid |= bord(x)  # type: ignore
+            invalid |= bord(x)  # type: ignore[arg-type]
         for x in db[hLen:one_pos]:
-            invalid |= bord(x)  # type: ignore
+            invalid |= bord(x)  # type: ignore[arg-type]
         if invalid != 0:
             raise ValueError("Incorrect decryption.")
         # Step 4
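The now-unsuppressed gmpy2.powmod calls are modular exponentiation with the same semantics as Python's three-argument pow, which is exactly what the RSAEP/RSADP steps above require. A standalone sanity check (assumes gmpy2 is installed; the numbers are arbitrary):

    import gmpy2

    # powmod(base, exp, mod) == pow(base, exp, mod); gmpy2 is just faster on big ints.
    base, exp, mod = 0x1234, 65537, 0xFFFD
    assert gmpy2.powmod(base, exp, mod) == pow(base, exp, mod)
    print(int(gmpy2.powmod(base, exp, mod)))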
@@ -3,7 +3,7 @@ from functools import wraps
 from typing import Any

 from flask import current_app, g, has_request_context, request
-from flask_login.config import EXEMPT_METHODS  # type: ignore
+from flask_login.config import EXEMPT_METHODS
 from werkzeug.local import LocalProxy

 from configs import dify_config

@@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
         if "_login_user" not in g:
             current_app.login_manager._load_user()  # type: ignore

-        return g._login_user  # type: ignore
+        return g._login_user

     return None

@@ -1,8 +1,8 @@
 import logging

-import sendgrid  # type: ignore
+import sendgrid
 from python_http_client.exceptions import ForbiddenError, UnauthorizedError
-from sendgrid.helpers.mail import Content, Email, Mail, To  # type: ignore
+from sendgrid.helpers.mail import Content, Email, Mail, To

 logger = logging.getLogger(__name__)

@@ -5,7 +5,7 @@ from datetime import datetime
 from typing import Any, Optional

 import sqlalchemy as sa
-from flask_login import UserMixin  # type: ignore[import-untyped]
+from flask_login import UserMixin
 from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.orm import Mapped, Session, mapped_column
 from typing_extensions import deprecated
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast

 import sqlalchemy as sa
 from flask import request
-from flask_login import UserMixin  # type: ignore[import-untyped]
+from flask_login import UserMixin
 from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
 from sqlalchemy.orm import Mapped, Session, mapped_column

@@ -16,7 +16,25 @@
     "opentelemetry.instrumentation.requests",
     "opentelemetry.instrumentation.sqlalchemy",
     "opentelemetry.instrumentation.redis",
-    "opentelemetry.instrumentation.httpx"
+    "opentelemetry.instrumentation.httpx",
+    "langfuse",
+    "cloudscraper",
+    "readabilipy",
+    "pypandoc",
+    "pypdfium2",
+    "webvtt",
+    "flask_compress",
+    "oss2",
+    "baidubce.auth.bce_credentials",
+    "baidubce.bce_client_configuration",
+    "baidubce.services.bos.bos_client",
+    "clickzetta",
+    "google.cloud",
+    "obs",
+    "qcloud_cos",
+    "tos",
+    "gmpy2",
+    "sendgrid",
+    "sendgrid.helpers.mail"
   ],
   "reportUnknownMemberType": "hint",
   "reportUnknownParameterType": "hint",

@@ -28,7 +46,7 @@
   "reportUnnecessaryComparison": "hint",
   "reportUnnecessaryIsInstance": "hint",
   "reportUntypedFunctionDecorator": "hint",
-
+  "reportUnnecessaryTypeIgnoreComment": "hint",
   "reportAttributeAccessIssue": "hint",
   "pythonVersion": "3.11",
   "pythonPlatform": "All"
@@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):

         try:
             repository_class = import_string(class_path)
-            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
+            return repository_class(session_maker=session_maker)
         except (ImportError, Exception) as e:
             raise RepositoryImportError(
                 f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"

@@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):

         try:
             repository_class = import_string(class_path)
-            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
+            return repository_class(session_maker=session_maker)
         except (ImportError, Exception) as e:
             raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e
@@ -7,7 +7,7 @@ from enum import StrEnum
 from urllib.parse import urlparse
 from uuid import uuid4

-import yaml  # type: ignore
+import yaml
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
 from packaging import version

@@ -563,7 +563,7 @@ class AppDslService:
         else:
             cls._append_model_config_export_data(export_data, app_model)

-        return yaml.dump(export_data, allow_unicode=True)  # type: ignore
+        return yaml.dump(export_data, allow_unicode=True)

     @classmethod
     def _append_workflow_export_data(
@ -241,9 +241,9 @@ class DatasetService:
|
|||
dataset.created_by = account.id
|
||||
dataset.updated_by = account.id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore
|
||||
dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore
|
||||
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore
|
||||
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.model if embedding_model else None
|
||||
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
|
||||
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
|
||||
dataset.provider = provider
|
||||
db.session.add(dataset)
|
||||
|
|
@ -1416,6 +1416,8 @@ class DocumentService:
|
|||
# check document limit
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
assert knowledge_config.data_source
|
||||
assert knowledge_config.data_source.info_list.file_info_list
|
||||
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
|
|
@ -1424,15 +1426,16 @@ class DocumentService:
|
|||
count = 0
|
||||
if knowledge_config.data_source:
|
||||
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
|
||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||
count = len(upload_file_list)
|
||||
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
||||
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
|
||||
for notion_info in notion_info_list: # type: ignore
|
||||
notion_info_list = knowledge_config.data_source.info_list.notion_info_list or []
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info.pages)
|
||||
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
|
||||
website_info = knowledge_config.data_source.info_list.website_info_list
|
||||
count = len(website_info.urls) # type: ignore
|
||||
assert website_info
|
||||
count = len(website_info.urls)
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
|
||||
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||
|
|
@ -1444,7 +1447,7 @@ class DocumentService:
|
|||
|
||||
# if dataset is empty, update dataset data_source_type
|
||||
if not dataset.data_source_type:
|
||||
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
|
||||
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
|
||||
|
||||
if not dataset.indexing_technique:
|
||||
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||
|
|
@@ -1481,7 +1484,7 @@ class DocumentService:
                 knowledge_config.retrieval_model.model_dump()
                 if knowledge_config.retrieval_model
                 else default_retrieval_model
-            ) # type: ignore
+            )

         documents = []
         if knowledge_config.original_document_id:
@@ -1523,11 +1526,12 @@ class DocumentService:
             db.session.flush()
         lock_name = f"add_document_lock_dataset_id_{dataset.id}"
         with redis_client.lock(lock_name, timeout=600):
+            assert dataset_process_rule
             position = DocumentService.get_documents_position(dataset.id)
             document_ids = []
             duplicate_document_ids = []
-            if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
                         db.session.query(UploadFile)
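Position assignment and the duplicate check above run under a Redis lock so concurrent uploads into one dataset cannot race on `position`. A minimal sketch of the idiom with redis-py (the connection and worker body are illustrative):

```python
import redis

redis_client = redis.Redis()  # assumed connection; Dify wires this through an extension

def add_documents(dataset_id: str) -> None:
    lock_name = f"add_document_lock_dataset_id_{dataset_id}"
    # Blocks until acquired; the 600s timeout releases the lock if the holder dies.
    with redis_client.lock(lock_name, timeout=600):
        ...  # read max position, build documents, flush, commit
```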
@@ -1540,7 +1544,7 @@ class DocumentService:
                         raise FileNotExistsError()

                     file_name = file.name
-                    data_source_info = {
+                    data_source_info: dict[str, str | bool] = {
                         "upload_file_id": file_id,
                     }
                     # check duplicate
@@ -1557,7 +1561,7 @@ class DocumentService:
                             .first()
                         )
                         if document:
-                            document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
+                            document.dataset_process_rule_id = dataset_process_rule.id
                             document.updated_at = naive_utc_now()
                             document.created_from = created_from
                             document.doc_form = knowledge_config.doc_form
@@ -1571,8 +1575,8 @@ class DocumentService:
                             continue
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id, # type: ignore
-                        knowledge_config.data_source.info_list.data_source_type, # type: ignore
+                        dataset_process_rule.id,
+                        knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,
                         data_source_info,
@@ -1587,7 +1591,7 @@ class DocumentService:
                     document_ids.append(document.id)
                     documents.append(document)
                     position += 1
-            elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
                 if not notion_info_list:
                     raise ValueError("No notion info list found.")
@@ -1616,15 +1620,15 @@ class DocumentService:
                         "credential_id": notion_info.credential_id,
                         "notion_workspace_id": workspace_id,
                         "notion_page_id": page.page_id,
-                        "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
+                        "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
                         "type": page.type,
                     }
                     # Truncate page name to 255 characters to prevent DB field length errors
                     truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id, # type: ignore
-                        knowledge_config.data_source.info_list.data_source_type, # type: ignore
+                        dataset_process_rule.id,
+                        knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,
                         data_source_info,
@@ -1644,8 +1648,8 @@ class DocumentService:
                 # delete not selected documents
                 if len(exist_document) > 0:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
-            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
-                website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
                 if not website_info:
                     raise ValueError("No website info list found.")
                 urls = website_info.urls
@@ -1663,8 +1667,8 @@ class DocumentService:
                         document_name = url
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id, # type: ignore
-                        knowledge_config.data_source.info_list.data_source_type, # type: ignore
+                        dataset_process_rule.id,
+                        knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,
                         data_source_info,
@@ -2071,7 +2075,7 @@ class DocumentService:
         # update document data source
         if document_data.data_source:
             file_name = ""
-            data_source_info = {}
+            data_source_info: dict[str, str | bool] = {}
             if document_data.data_source.info_list.data_source_type == "upload_file":
                 if not document_data.data_source.info_list.file_info_list:
                     raise ValueError("No file info list found.")
@@ -2128,7 +2132,7 @@ class DocumentService:
                     "url": url,
                     "provider": website_info.provider,
                     "job_id": website_info.job_id,
-                    "only_main_content": website_info.only_main_content, # type: ignore
+                    "only_main_content": website_info.only_main_content,
                     "mode": "crawl",
                 }
             document.data_source_type = document_data.data_source.info_list.data_source_type
@@ -2154,7 +2158,7 @@ class DocumentService:

             db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
                 {DocumentSegment.status: "re_segment"}
-            ) # type: ignore
+            )
             db.session.commit()
             # trigger async task
             document_indexing_update_task.delay(document.dataset_id, document.id)
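The re-segment path issues one query-level UPDATE for every segment of the document and only then schedules the async re-index. A sketch of the equivalent SQLAlchemy 2.0-style statement, assuming a session and a mapped `DocumentSegment` like the ones used here:

```python
from sqlalchemy import update

# One UPDATE for all segments of the document, without loading the rows
# into the session first.
stmt = (
    update(DocumentSegment)
    .where(DocumentSegment.document_id == document.id)
    .values(status="re_segment")
)
db.session.execute(stmt)
db.session.commit()
```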
@@ -2164,25 +2168,26 @@ class DocumentService:
     def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
+        assert knowledge_config.data_source

         features = FeatureService.get_features(current_user.current_tenant_id)

         if features.billing.enabled:
             count = 0
-            if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
                 upload_file_list = (
-                    knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
-                    if knowledge_config.data_source.info_list.file_info_list # type: ignore
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
                     else []
                 )
                 count = len(upload_file_list)
-            elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
-                notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                 if notion_info_list:
                     for notion_info in notion_info_list:
                         count = count + len(notion_info.pages)
-            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
-                website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
                 if website_info:
                     count = len(website_info.urls)
             if features.billing.subscription.plan == "sandbox" and count > 1:
@@ -2196,9 +2201,11 @@ class DocumentService:
         dataset_collection_binding_id = None
         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
+            assert knowledge_config.embedding_model_provider
+            assert knowledge_config.embedding_model
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider, # type: ignore
-                knowledge_config.embedding_model, # type: ignore
+                knowledge_config.embedding_model_provider,
+                knowledge_config.embedding_model,
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if knowledge_config.retrieval_model:
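Each `# type: ignore` removed above is paid for by an `assert` immediately before the call: asserting that an Optional field is set narrows its type for the checker and fails fast at runtime if the invariant is broken. A minimal sketch of the pattern, with a hypothetical config model:

```python
from pydantic import BaseModel

class KnowledgeConfigSketch(BaseModel):
    # Optional until the caller picks "high_quality" indexing.
    embedding_model_provider: str | None = None
    embedding_model: str | None = None

def bind_collection(config: KnowledgeConfigSketch) -> str:
    assert config.embedding_model_provider  # narrows str | None to str
    assert config.embedding_model
    # Both values are plain str from here on, for mypy and at runtime.
    return f"{config.embedding_model_provider}/{config.embedding_model}"
```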
@@ -2215,7 +2222,7 @@ class DocumentService:
         dataset = Dataset(
             tenant_id=tenant_id,
             name="",
-            data_source_type=knowledge_config.data_source.info_list.data_source_type, # type: ignore
+            data_source_type=knowledge_config.data_source.info_list.data_source_type,
             indexing_technique=knowledge_config.indexing_technique,
             created_by=account.id,
             embedding_model=knowledge_config.embedding_model,
@@ -2224,7 +2231,7 @@ class DocumentService:
             retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
         )

-        db.session.add(dataset) # type: ignore
+        db.session.add(dataset)
         db.session.flush()

         documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)
@@ -88,7 +88,7 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()

-        return cls.compact_retrieve_response(query, all_documents) # type: ignore
+        return cls.compact_retrieve_response(query, all_documents)

     @classmethod
     def external_retrieve(
@@ -1,4 +1,4 @@
-import boto3 # type: ignore
+import boto3

 from configs import dify_config
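A bare `# type: ignore` on a third-party import usually papers over missing type stubs; once stubs are available the suppression can be dropped, as here. A sketch of the usual fix under mypy (the stub package name is an assumption about tooling, not part of this change):

```python
# dev requirements (illustrative):
#   boto3-stubs[s3]   # typed boto3 surface, including the client factory
#
# With stubs installed this import type-checks without a suppression:
import boto3

s3 = boto3.client("s3")  # the "s3" literal now selects a typed client
```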
@@ -89,7 +89,7 @@ class MetadataService:
                 document.doc_metadata = doc_metadata
                 db.session.add(document)
             db.session.commit()
-            return metadata # type: ignore
+            return metadata
         except Exception:
             logger.exception("Update metadata name failed")
         finally:
@@ -137,7 +137,7 @@ class ModelProviderService:
         :return:
         """
         provider_configuration = self._get_provider_configuration(tenant_id, provider)
-        return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
+        return provider_configuration.get_provider_credential(credential_id=credential_id)

     def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
         """
@@ -225,7 +225,7 @@ class ModelProviderService:
         :return:
         """
         provider_configuration = self._get_provider_configuration(tenant_id, provider)
-        return provider_configuration.get_custom_model_credential( # type: ignore
+        return provider_configuration.get_custom_model_credential(
             model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
         )
@@ -146,7 +146,7 @@ class PluginMigration:
                 futures.append(
                     thread_pool.submit(
                         process_tenant,
-                        current_app._get_current_object(), # type: ignore[attr-defined]
+                        current_app._get_current_object(), # type: ignore
                         tenant_id,
                     )
                 )
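`current_app` is a context-local proxy, so it cannot be used from a worker thread directly; `_get_current_object()` unwraps the real `Flask` app before it crosses the thread boundary, and the worker re-enters an app context. A minimal sketch of the pattern (the worker body is illustrative):

```python
from concurrent.futures import ThreadPoolExecutor

from flask import Flask, current_app

def process_tenant(app: Flask, tenant_id: str) -> None:
    # Re-enter an application context inside the worker thread.
    with app.app_context():
        current_app.logger.info("processing tenant %s", tenant_id)

with ThreadPoolExecutor() as pool:
    # Unwrap the proxy while still inside the owning context.
    pool.submit(process_tenant, current_app._get_current_object(), "tenant-1")
```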
@@ -544,8 +544,8 @@ class BuiltinToolManageService:
             try:
                 # handle include, exclude
                 if is_filtered(
-                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
-                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
+                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
+                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
                     data=provider_controller,
                     name_func=lambda x: x.entity.identity.name,
                 ):
@@ -308,7 +308,7 @@ class MCPToolManageService:
         provider_controller = MCPToolProviderController.from_db(mcp_provider)
         tool_configuration = ProviderConfigEncrypter(
             tenant_id=mcp_provider.tenant_id,
-            config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
+            config=list(provider_controller.get_credentials_schema()),
             provider_config_cache=NoOpProviderCredentialCache(),
         )
         credentials = tool_configuration.encrypt(credentials)
@@ -102,7 +102,7 @@ def batch_create_segment_to_index_task(
     for segment, tokens in zip(content, tokens_list):
         content = segment["content"]
         doc_id = str(uuid.uuid4())
-        segment_hash = helper.generate_text_hash(content) # type: ignore
+        segment_hash = helper.generate_text_hash(content)
         max_position = (
             db.session.query(func.max(DocumentSegment.position))
             .where(DocumentSegment.document_id == dataset_document.id)
@@ -5,11 +5,11 @@ from unittest.mock import MagicMock

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from pymochow import MochowClient # type: ignore
-from pymochow.model.database import Database # type: ignore
-from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore
-from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore
-from pymochow.model.table import Table # type: ignore
+from pymochow import MochowClient
+from pymochow.model.database import Database
+from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
+from pymochow.model.schema import HNSWParams, VectorIndex
+from pymochow.model.table import Table


 class AttrDict(UserDict):
@@ -3,15 +3,15 @@ from typing import Any, Union

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from tcvectordb import RPCVectorDBClient # type: ignore
+from tcvectordb import RPCVectorDBClient
 from tcvectordb.model import enum
 from tcvectordb.model.collection import FilterIndexConfig
-from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
-from tcvectordb.model.enum import ReadConsistency # type: ignore
-from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
+from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
+from tcvectordb.model.enum import ReadConsistency
+from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
 from tcvectordb.rpc.model.collection import RPCCollection
 from tcvectordb.rpc.model.database import RPCDatabase
-from xinference_client.types import Embedding # type: ignore
+from xinference_client.types import Embedding


 class MockTcvectordbClass:
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from volcengine.viking_db import ( # type: ignore
+from volcengine.viking_db import (
     Collection,
     Data,
     DistanceType,
@@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
         """Test with None input"""
         # The method signature expects Union[dict, list, Segment], but implementation handles None
         # We'll test the actual behavior by passing an empty dict instead
-        result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore
+        result = WorkflowResponseConverter._fetch_files_from_variable_value(None)
         assert result == []

     def test_fetch_files_from_variable_value_with_empty_dict(self):
@@ -235,7 +235,7 @@ class TestIndividualHandlers:
         # Type assertion needed due to union type
         text_content = result.content[0]
         assert hasattr(text_content, "text")
-        assert text_content.text == "test answer" # type: ignore[attr-defined]
+        assert text_content.text == "test answer"

     def test_handle_call_tool_no_end_user(self):
         """Test call tool handler without end user"""
@@ -0,0 +1,181 @@
from __future__ import annotations

import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom


class _TestNode(Node):
    node_type = NodeType.ANSWER
    execution_type = NodeExecutionType.EXECUTABLE

    @classmethod
    def version(cls) -> str:
        return "test"

    def __init__(
        self,
        *,
        id: str,
        config: Mapping[str, object],
        graph_init_params: GraphInitParams,
        graph_runtime_state: GraphRuntimeState,
    ) -> None:
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        data = config.get("data", {})
        if isinstance(data, Mapping):
            execution_type = data.get("execution_type")
            if isinstance(execution_type, str):
                self.execution_type = NodeExecutionType(execution_type)
        self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
        self.data: dict[str, object] = {}

    def init_node_data(self, data: Mapping[str, object]) -> None:
        title = str(data.get("title", self.id))
        desc = data.get("description")
        error_strategy_value = data.get("error_strategy")
        error_strategy: ErrorStrategy | None = None
        if isinstance(error_strategy_value, ErrorStrategy):
            error_strategy = error_strategy_value
        elif isinstance(error_strategy_value, str):
            error_strategy = ErrorStrategy(error_strategy_value)
        self._base_node_data = BaseNodeData(
            title=title,
            desc=str(desc) if desc is not None else None,
            error_strategy=error_strategy,
        )
        self.data = dict(data)

    def _run(self):
        raise NotImplementedError

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._base_node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._base_node_data.retry_config

    def _get_title(self) -> str:
        return self._base_node_data.title

    def _get_description(self) -> str | None:
        return self._base_node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._base_node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._base_node_data


@dataclass(slots=True)
class _SimpleNodeFactory:
    graph_init_params: GraphInitParams
    graph_runtime_state: GraphRuntimeState

    def create_node(self, node_config: Mapping[str, object]) -> _TestNode:
        node_id = str(node_config["id"])
        node = _TestNode(
            id=node_id,
            config=node_config,
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )
        node.init_node_data(node_config.get("data", {}))
        return node


@pytest.fixture
def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]:
    graph_config: dict[str, object] = {"edges": [], "nodes": []}
    init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={})
    runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
    factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state)
    return factory, graph_config


def test_graph_initialization_runs_default_validators(
    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
):
    node_factory, graph_config = graph_init_dependencies
    graph_config["nodes"] = [
        {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
        {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}},
    ]
    graph_config["edges"] = [
        {"source": "start", "target": "answer", "sourceHandle": "success"},
    ]

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    assert graph.root_node.id == "start"
    assert "answer" in graph.nodes


def test_graph_validation_fails_for_unknown_edge_targets(
    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:
    node_factory, graph_config = graph_init_dependencies
    graph_config["nodes"] = [
        {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
    ]
    graph_config["edges"] = [
        {"source": "start", "target": "missing", "sourceHandle": "success"},
    ]

    with pytest.raises(GraphValidationError) as exc:
        Graph.init(graph_config=graph_config, node_factory=node_factory)

    assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues)


def test_graph_promotes_fail_branch_nodes_to_branch_execution_type(
    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:
    node_factory, graph_config = graph_init_dependencies
    graph_config["nodes"] = [
        {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
        {
            "id": "branch",
            "data": {
                "type": NodeType.IF_ELSE,
                "title": "Branch",
                "error_strategy": ErrorStrategy.FAIL_BRANCH,
            },
        },
    ]
    graph_config["edges"] = [
        {"source": "start", "target": "branch", "sourceHandle": "success"},
    ]

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH
@@ -212,7 +212,7 @@ class TestValidateResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select", # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],

@@ -400,7 +400,7 @@ class TestTransformResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select", # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],

@@ -414,7 +414,7 @@ class TestTransformResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select", # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],
@@ -248,4 +248,4 @@ def test_constructor_with_extra_key():
     # Test that SystemVariable should forbid extra keys
     with pytest.raises(ValidationError):
         # This should fail because there is an unexpected key.
-        SystemVariable(invalid_key=1) # type: ignore
+        SystemVariable(invalid_key=1)
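The test relies on the model rejecting unknown keys; in Pydantic v2 that behaviour comes from `model_config = ConfigDict(extra="forbid")`. A minimal sketch with a hypothetical field:

```python
import pytest
from pydantic import BaseModel, ConfigDict, ValidationError

class SystemVariableSketch(BaseModel):
    model_config = ConfigDict(extra="forbid")
    user_id: str | None = None  # illustrative field

with pytest.raises(ValidationError):
    SystemVariableSketch(invalid_key=1)  # unknown key raises ValidationError
```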
@@ -14,36 +14,36 @@ def _create_api_app():
     api = ExternalApi(bp)

     @api.route("/bad-request")
-    class Bad(Resource): # type: ignore
-        def get(self): # type: ignore
+    class Bad(Resource):
+        def get(self):
             raise BadRequest("invalid input")

     @api.route("/unauth")
-    class Unauth(Resource): # type: ignore
-        def get(self): # type: ignore
+    class Unauth(Resource):
+        def get(self):
             raise Unauthorized("auth required")

     @api.route("/value-error")
-    class ValErr(Resource): # type: ignore
-        def get(self): # type: ignore
+    class ValErr(Resource):
+        def get(self):
             raise ValueError("boom")

     @api.route("/quota")
-    class Quota(Resource): # type: ignore
-        def get(self): # type: ignore
+    class Quota(Resource):
+        def get(self):
             raise AppInvokeQuotaExceededError("quota exceeded")

     @api.route("/general")
-    class Gen(Resource): # type: ignore
-        def get(self): # type: ignore
+    class Gen(Resource):
+        def get(self):
             raise RuntimeError("oops")

     # Note: We avoid altering default_mediatype to keep normal error paths

     # Special 400 message rewrite
     @api.route("/json-empty")
-    class JsonEmpty(Resource): # type: ignore
-        def get(self): # type: ignore
+    class JsonEmpty(Resource):
+        def get(self):
             e = BadRequest()
             # Force the specific message the handler rewrites
             e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"

@@ -51,11 +51,11 @@ def _create_api_app():

     # 400 mapping payload path
     @api.route("/param-errors")
-    class ParamErrors(Resource): # type: ignore
-        def get(self): # type: ignore
+    class ParamErrors(Resource):
+        def get(self):
             e = BadRequest()
             # Coerce a mapping description to trigger param error shaping
-            e.description = {"field": "is required"} # type: ignore[assignment]
+            e.description = {"field": "is required"}
             raise e

     app.register_blueprint(bp, url_prefix="/api")
@@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():

     orig_exc_info = ext.sys.exc_info
     try:
-        ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment]
+        ext.sys.exc_info = lambda: (None, None, None)

         app = _create_api_app()
         client = app.test_client()
@@ -67,7 +67,7 @@ def test_current_user_not_accessible_across_threads(login_app: Flask, test_user:
             # without preserve_flask_contexts
            result["user_accessible"] = current_user.is_authenticated
         except Exception as e:
-            result["error"] = str(e) # type: ignore
+            result["error"] = str(e)

     # Run the function in a separate thread
     thread = threading.Thread(target=check_user_in_thread)

@@ -110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask,
             else:
                 result["user_accessible"] = False
         except Exception as e:
-            result["error"] = str(e) # type: ignore
+            result["error"] = str(e)

     # Run the function in a separate thread
     thread = threading.Thread(target=check_user_in_thread_with_manager)
@@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented():
         oauth.get_raw_user_info("token")

     with pytest.raises(NotImplementedError):
-        oauth._transform_user_info({}) # type: ignore[name-defined]
+        oauth._transform_user_info({})
@@ -3,8 +3,8 @@ from unittest.mock import MagicMock

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from qcloud_cos import CosS3Client # type: ignore
-from qcloud_cos.streambody import StreamBody # type: ignore
+from qcloud_cos import CosS3Client
+from qcloud_cos.streambody import StreamBody

 from tests.unit_tests.oss.__mock.base import (
     get_example_bucket,
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock

 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from tos import TosClientV2 # type: ignore
-from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore
+from tos import TosClientV2
+from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput

 from tests.unit_tests.oss.__mock.base import (
     get_example_bucket,
@@ -1,7 +1,7 @@
 from unittest.mock import patch

 import pytest
-from qcloud_cos import CosConfig # type: ignore
+from qcloud_cos import CosConfig

 from extensions.storage.tencent_cos_storage import TencentCosStorage
 from tests.unit_tests.oss.__mock.base import (
@@ -1,7 +1,7 @@
 from unittest.mock import patch

 import pytest
-from tos import TosClientV2 # type: ignore
+from tos import TosClientV2

 from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
 from tests.unit_tests.oss.__mock.base import (
@@ -125,13 +125,13 @@ class TestApiKeyAuthService:
         mock_session.commit = Mock()

         args_copy = self.mock_args.copy()
-        original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore
+        original_key = args_copy["credentials"]["config"]["api_key"]

         ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)

         # Verify original key is replaced with encrypted key
-        assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore
-        assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore
+        assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
+        assert args_copy["credentials"]["config"]["api_key"] != original_key

         # Verify encryption function is called correctly
         mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)

@@ -268,7 +268,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_empty_credentials(self):
         """Test API key auth args validation - empty credentials"""
         args = self.mock_args.copy()
-        args["credentials"] = None # type: ignore
+        args["credentials"] = None

         with pytest.raises(ValueError, match="credentials is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)

@@ -284,7 +284,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_missing_auth_type(self):
         """Test API key auth args validation - missing auth_type"""
         args = self.mock_args.copy()
-        del args["credentials"]["auth_type"] # type: ignore
+        del args["credentials"]["auth_type"]

         with pytest.raises(ValueError, match="auth_type is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)

@@ -292,7 +292,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_empty_auth_type(self):
         """Test API key auth args validation - empty auth_type"""
         args = self.mock_args.copy()
-        args["credentials"]["auth_type"] = "" # type: ignore
+        args["credentials"]["auth_type"] = ""

         with pytest.raises(ValueError, match="auth_type is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)

@@ -380,7 +380,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
         """Test API key auth args validation - dict credentials with list auth_type"""
         args = self.mock_args.copy()
-        args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string
+        args["credentials"]["auth_type"] = ["api_key"]

         # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
         # So this should not raise exception, this test should pass
@@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter:
         encrypter = SystemOAuthEncrypter("test_secret")

         with pytest.raises(Exception):  # noqa: B017
-            encrypter.encrypt_oauth_params(None) # type: ignore
+            encrypter.encrypt_oauth_params(None)

         with pytest.raises(Exception):  # noqa: B017
-            encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
+            encrypter.encrypt_oauth_params("not_a_dict")

     def test_decrypt_oauth_params_basic(self):
         """Test basic OAuth parameters decryption"""

@@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter:
         encrypter = SystemOAuthEncrypter("test_secret")

         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(123) # type: ignore
+            encrypter.decrypt_oauth_params(123)

         assert "encrypted_data must be a string" in str(exc_info.value)

         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(None) # type: ignore
+            encrypter.decrypt_oauth_params(None)

         assert "encrypted_data must be a string" in str(exc_info.value)

@@ -461,14 +461,14 @@ class TestConvenienceFunctions:
         """Test convenience functions with error conditions"""
         # Test encryption with invalid input
         with pytest.raises(Exception):  # noqa: B017
-            encrypt_system_oauth_params(None) # type: ignore
+            encrypt_system_oauth_params(None)

         # Test decryption with invalid input
         with pytest.raises(ValueError):
             decrypt_system_oauth_params("")

         with pytest.raises(ValueError):
-            decrypt_system_oauth_params(None) # type: ignore
+            decrypt_system_oauth_params(None)


 class TestErrorHandling:

@@ -501,7 +501,7 @@ class TestErrorHandling:

         # Test non-string error
         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(123) # type: ignore
+            encrypter.decrypt_oauth_params(123)
         assert "encrypted_data must be a string" in str(exc_info.value)

         # Test invalid format error
@@ -0,0 +1,4 @@
// Mock for context-block plugin to avoid circular dependency in Storybook
export const ContextBlockNode = null
export const ContextBlockReplacementBlock = null
export default null

@@ -0,0 +1,4 @@
// Mock for history-block plugin to avoid circular dependency in Storybook
export const HistoryBlockNode = null
export const HistoryBlockReplacementBlock = null
export default null

@@ -0,0 +1,4 @@
// Mock for query-block plugin to avoid circular dependency in Storybook
export const QueryBlockNode = null
export const QueryBlockReplacementBlock = null
export default null
@@ -1,4 +1,9 @@
 import type { StorybookConfig } from '@storybook/nextjs'
+import path from 'node:path'
+import { fileURLToPath } from 'node:url'
+
+const __filename = fileURLToPath(import.meta.url)
+const __dirname = path.dirname(__filename)

 const config: StorybookConfig = {
   stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'],

@@ -25,5 +30,17 @@ const config: StorybookConfig = {
   docs: {
     defaultName: 'Documentation',
   },
+  webpackFinal: async (config) => {
+    // Add alias to mock problematic modules with circular dependencies
+    config.resolve = config.resolve || {}
+    config.resolve.alias = {
+      ...config.resolve.alias,
+      // Mock the plugin index files to avoid circular dependencies
+      [path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/context-block/index.tsx')]: path.resolve(__dirname, '__mocks__/context-block.tsx'),
+      [path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/history-block/index.tsx')]: path.resolve(__dirname, '__mocks__/history-block.tsx'),
+      [path.resolve(__dirname, '../app/components/base/prompt-editor/plugins/query-block/index.tsx')]: path.resolve(__dirname, '__mocks__/query-block.tsx'),
+    }
+    return config
+  },
 }
 export default config
@@ -160,8 +160,7 @@ describe('Navigation Utilities', () => {
       page: 1,
       limit: '',
       keyword: 'test',
       empty: null,
-      undefined,
       filter: '',
     })

     expect(path).toBe('/datasets/123/documents?page=1&keyword=test')
@@ -39,28 +39,38 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa
     const isDarkQuery = DARK_MODE_MEDIA_QUERY.test(query)
     const matches = isDarkQuery ? systemPrefersDark : false

+    const handleAddListener = (listener: (event: MediaQueryListEvent) => void) => {
+      listeners.add(listener)
+    }
+
+    const handleRemoveListener = (listener: (event: MediaQueryListEvent) => void) => {
+      listeners.delete(listener)
+    }
+
+    const handleAddEventListener = (_event: string, listener: EventListener) => {
+      if (typeof listener === 'function')
+        listeners.add(listener as (event: MediaQueryListEvent) => void)
+    }
+
+    const handleRemoveEventListener = (_event: string, listener: EventListener) => {
+      if (typeof listener === 'function')
+        listeners.delete(listener as (event: MediaQueryListEvent) => void)
+    }
+
+    const handleDispatchEvent = (event: Event) => {
+      listeners.forEach(listener => listener(event as MediaQueryListEvent))
+      return true
+    }
+
     const mediaQueryList: MediaQueryList = {
       matches,
       media: query,
       onchange: null,
-      addListener: (listener: MediaQueryListListener) => {
-        listeners.add(listener)
-      },
-      removeListener: (listener: MediaQueryListListener) => {
-        listeners.delete(listener)
-      },
-      addEventListener: (_event, listener: EventListener) => {
-        if (typeof listener === 'function')
-          listeners.add(listener as MediaQueryListListener)
-      },
-      removeEventListener: (_event, listener: EventListener) => {
-        if (typeof listener === 'function')
-          listeners.delete(listener as MediaQueryListListener)
-      },
-      dispatchEvent: (event: Event) => {
-        listeners.forEach(listener => listener(event as MediaQueryListEvent))
-        return true
-      },
+      addListener: handleAddListener,
+      removeListener: handleRemoveListener,
+      addEventListener: handleAddEventListener,
+      removeEventListener: handleRemoveEventListener,
+      dispatchEvent: handleDispatchEvent,
     }

     return mediaQueryList
@@ -69,6 +79,121 @@ const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = fa
   jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia)
 }

+// Helper function to create timing page component
+const createTimingPageComponent = (
+  timingData: Array<{ phase: string; timestamp: number; styles: { backgroundColor: string; color: string } }>,
+) => {
+  const recordTiming = (phase: string, styles: { backgroundColor: string; color: string }) => {
+    timingData.push({
+      phase,
+      timestamp: performance.now(),
+      styles,
+    })
+  }
+
+  const TimingPageComponent = () => {
+    const [mounted, setMounted] = useState(false)
+    const { theme } = useTheme()
+    const isDark = mounted ? theme === 'dark' : false
+
+    const currentStyles = {
+      backgroundColor: isDark ? '#1f2937' : '#ffffff',
+      color: isDark ? '#ffffff' : '#000000',
+    }
+
+    recordTiming(mounted ? 'CSR' : 'Initial', currentStyles)
+
+    useEffect(() => {
+      setMounted(true)
+    }, [])
+
+    return (
+      <div
+        data-testid="timing-page"
+        style={currentStyles}
+      >
+        <div data-testid="timing-status">
+          Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'}
+        </div>
+      </div>
+    )
+  }
+
+  return TimingPageComponent
+}
+
+// Helper function to create CSS test component
+const createCSSTestComponent = (
+  cssStates: Array<{ className: string; timestamp: number }>,
+) => {
+  const recordCSSState = (className: string) => {
+    cssStates.push({
+      className,
+      timestamp: performance.now(),
+    })
+  }
+
+  const CSSTestComponent = () => {
+    const [mounted, setMounted] = useState(false)
+    const { theme } = useTheme()
+    const isDark = mounted ? theme === 'dark' : false
+
+    const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}`
+
+    recordCSSState(className)
+
+    useEffect(() => {
+      setMounted(true)
+    }, [])
+
+    return (
+      <div
+        data-testid="css-component"
+        className={className}
+      >
+        <div data-testid="css-classes">Classes: {className}</div>
+      </div>
+    )
+  }
+
+  return CSSTestComponent
+}
+
+// Helper function to create performance test component
+const createPerformanceTestComponent = (
+  performanceMarks: Array<{ event: string; timestamp: number }>,
+) => {
+  const recordPerformanceMark = (event: string) => {
+    performanceMarks.push({ event, timestamp: performance.now() })
+  }
+
+  const PerformanceTestComponent = () => {
+    const [mounted, setMounted] = useState(false)
+    const { theme } = useTheme()
+
+    recordPerformanceMark('component-render')
+
+    useEffect(() => {
+      recordPerformanceMark('mount-start')
+      setMounted(true)
+      recordPerformanceMark('mount-complete')
+    }, [])
+
+    useEffect(() => {
+      if (theme)
+        recordPerformanceMark('theme-available')
+    }, [theme])
+
+    return (
+      <div data-testid="performance-test">
+        Mounted: {mounted.toString()} | Theme: {theme || 'loading'}
+      </div>
+    )
+  }
+
+  return PerformanceTestComponent
+}
+
 // Simulate real page component based on Dify's actual theme usage
 const PageComponent = () => {
   const [mounted, setMounted] = useState(false)
@@ -227,39 +352,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
     setupMockEnvironment('dark')

     const timingData: Array<{ phase: string; timestamp: number; styles: any }> = []

-    const TimingPageComponent = () => {
-      const [mounted, setMounted] = useState(false)
-      const { theme } = useTheme()
-      const isDark = mounted ? theme === 'dark' : false
-
-      // Record timing and styles for each render phase
-      const currentStyles = {
-        backgroundColor: isDark ? '#1f2937' : '#ffffff',
-        color: isDark ? '#ffffff' : '#000000',
-      }
-
-      timingData.push({
-        phase: mounted ? 'CSR' : 'Initial',
-        timestamp: performance.now(),
-        styles: currentStyles,
-      })
-
-      useEffect(() => {
-        setMounted(true)
-      }, [])
-
-      return (
-        <div
-          data-testid="timing-page"
-          style={currentStyles}
-        >
-          <div data-testid="timing-status">
-            Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'}
-          </div>
-        </div>
-      )
-    }
+    const TimingPageComponent = createTimingPageComponent(timingData)

     render(
       <TestThemeProvider>
@@ -295,33 +388,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
     setupMockEnvironment('dark')

     const cssStates: Array<{ className: string; timestamp: number }> = []

-    const CSSTestComponent = () => {
-      const [mounted, setMounted] = useState(false)
-      const { theme } = useTheme()
-      const isDark = mounted ? theme === 'dark' : false
-
-      // Simulate Tailwind CSS class application
-      const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}`
-
-      cssStates.push({
-        className,
-        timestamp: performance.now(),
-      })
-
-      useEffect(() => {
-        setMounted(true)
-      }, [])
-
-      return (
-        <div
-          data-testid="css-component"
-          className={className}
-        >
-          <div data-testid="css-classes">Classes: {className}</div>
-        </div>
-      )
-    }
+    const CSSTestComponent = createCSSTestComponent(cssStates)

     render(
       <TestThemeProvider>
@@ -413,34 +480,12 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
   test('verifies ThemeProvider position fix reduces initialization delay', async () => {
     const performanceMarks: Array<{ event: string; timestamp: number }> = []

-    const PerformanceTestComponent = () => {
-      const [mounted, setMounted] = useState(false)
-      const { theme } = useTheme()
-
-      performanceMarks.push({ event: 'component-render', timestamp: performance.now() })
-
-      useEffect(() => {
-        performanceMarks.push({ event: 'mount-start', timestamp: performance.now() })
-        setMounted(true)
-        performanceMarks.push({ event: 'mount-complete', timestamp: performance.now() })
-      }, [])
-
-      useEffect(() => {
-        if (theme)
-          performanceMarks.push({ event: 'theme-available', timestamp: performance.now() })
-      }, [theme])
-
-      return (
-        <div data-testid="performance-test">
-          Mounted: {mounted.toString()} | Theme: {theme || 'loading'}
-        </div>
-      )
-    }
-
     setupMockEnvironment('dark')

     expect(window.localStorage.getItem('theme')).toBe('dark')

+    const PerformanceTestComponent = createPerformanceTestComponent(performanceMarks)
+
     render(
       <TestThemeProvider>
         <PerformanceTestComponent />
@@ -70,14 +70,18 @@ describe('Unified Tags Editing - Pure Logic Tests', () => {
   })

   describe('Fallback Logic (from layout-main.tsx)', () => {
+    type Tag = { id: string; name: string }
+    type AppDetail = { tags: Tag[] }
+    type FallbackResult = { tags?: Tag[] } | null
+    // no-op
     it('should trigger fallback when tags are missing or empty', () => {
-      const appDetailWithoutTags = { tags: [] }
-      const appDetailWithTags = { tags: [{ id: 'tag1' }] }
-      const appDetailWithUndefinedTags = { tags: undefined as any }
+      const appDetailWithoutTags: AppDetail = { tags: [] }
+      const appDetailWithTags: AppDetail = { tags: [{ id: 'tag1', name: 't' }] }
+      const appDetailWithUndefinedTags: { tags: Tag[] | undefined } = { tags: undefined }

       // This simulates the condition in layout-main.tsx
-      const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0
-      const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0
+      const shouldFallback1 = appDetailWithoutTags.tags.length === 0
+      const shouldFallback2 = appDetailWithTags.tags.length === 0
       const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0

       expect(shouldFallback1).toBe(true) // Empty array should trigger fallback

@@ -86,24 +90,26 @@ describe('Unified Tags Editing - Pure Logic Tests', () => {
     })

     it('should preserve tags when fallback succeeds', () => {
-      const originalAppDetail = { tags: [] as any[] }
-      const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] }
+      const originalAppDetail: AppDetail = { tags: [] }
+      const fallbackResult: { tags?: Tag[] } = { tags: [{ id: 'tag1', name: 'fallback-tag' }] }

       // This simulates the successful fallback in layout-main.tsx
-      if (fallbackResult?.tags)
-        originalAppDetail.tags = fallbackResult.tags
+      const tags = fallbackResult.tags
+      if (tags)
+        originalAppDetail.tags = tags

       expect(originalAppDetail.tags).toEqual(fallbackResult.tags)
       expect(originalAppDetail.tags.length).toBe(1)
     })

     it('should continue with empty tags when fallback fails', () => {
-      const originalAppDetail: { tags: any[] } = { tags: [] }
-      const fallbackResult: { tags?: any[] } | null = null
+      const originalAppDetail: AppDetail = { tags: [] }
+      const fallbackResult = null as FallbackResult

       // This simulates fallback failure in layout-main.tsx
-      if (fallbackResult?.tags)
-        originalAppDetail.tags = fallbackResult.tags
+      const tags: Tag[] | undefined = fallbackResult && 'tags' in fallbackResult ? fallbackResult.tags : undefined
+      if (tags)
+        originalAppDetail.tags = tags

       expect(originalAppDetail.tags).toEqual([])
     })
@@ -73,7 +73,7 @@ const ConfigPopup: FC<PopupProps> = ({
     }
   }, [onChooseProvider])

-  const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => {
+  const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => {
     onConfigUpdated(currentProvider!, payload)
     hideConfigModal()
   }, [currentProvider, hideConfigModal, onConfigUpdated])
@@ -6,7 +6,6 @@ import { useWebAppStore } from '@/context/web-app-context'
 import { useRouter, useSearchParams } from 'next/navigation'
 import AppUnavailable from '@/app/components/base/app-unavailable'
-import { useTranslation } from 'react-i18next'
 import { AccessMode } from '@/models/access-control'
 import { webAppLoginStatus, webAppLogout } from '@/service/webapp-auth'
 import { fetchAccessToken } from '@/service/share'
 import Loading from '@/app/components/base/loading'

@@ -35,7 +34,6 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
     router.replace(url)
   }, [getSigninUrl, router, webAppLogout, shareCode])

-  const needCheckIsLogin = webAppAccessMode !== AccessMode.PUBLIC
   const [isLoading, setIsLoading] = useState(true)
   useEffect(() => {
     if (message) {

@@ -58,8 +56,8 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
     }

     (async () => {
-      const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(needCheckIsLogin, shareCode!)
-
+      // if access mode is public, user login is always true, but the app login(passport) may be expired
+      const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(shareCode!)
       if (userLoggedIn && appLoggedIn) {
         redirectOrFinish()
       }

@@ -87,7 +85,6 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
     router,
     message,
     webAppAccessMode,
-    needCheckIsLogin,
     tokenFromUrl])

   if (message) {
@ -0,0 +1,262 @@
|
|||
import type { Meta, StoryObj } from '@storybook/nextjs'
|
||||
import { RiAddLine, RiDeleteBinLine, RiEditLine, RiMore2Fill, RiSaveLine, RiShareLine } from '@remixicon/react'
|
||||
import ActionButton, { ActionButtonState } from '.'
|
||||
|
||||
const meta = {
|
||||
title: 'Base/ActionButton',
|
||||
component: ActionButton,
|
||||
parameters: {
|
||||
layout: 'centered',
|
||||
docs: {
|
||||
description: {
|
||||
component: 'Action button component with multiple sizes and states. Commonly used for toolbar actions and inline operations.',
|
||||
},
|
||||
},
|
||||
},
|
||||
tags: ['autodocs'],
|
||||
argTypes: {
|
||||
size: {
|
||||
control: 'select',
|
||||
options: ['xs', 'm', 'l', 'xl'],
|
||||
description: 'Button size',
|
||||
},
|
||||
state: {
|
||||
control: 'select',
|
||||
options: [
|
||||
ActionButtonState.Default,
|
||||
ActionButtonState.Active,
|
||||
ActionButtonState.Disabled,
|
||||
ActionButtonState.Destructive,
|
||||
ActionButtonState.Hover,
|
||||
],
|
||||
description: 'Button state',
|
||||
},
|
||||
children: {
|
||||
control: 'text',
|
||||
      description: 'Button content',
    },
    disabled: {
      control: 'boolean',
      description: 'Native disabled state',
    },
  },
} satisfies Meta<typeof ActionButton>

export default meta
type Story = StoryObj<typeof meta>

// Default state
export const Default: Story = {
  args: {
    size: 'm',
    children: <RiEditLine className="h-4 w-4" />,
  },
}

// With text
export const WithText: Story = {
  args: {
    size: 'm',
    children: 'Edit',
  },
}

// Icon with text
export const IconWithText: Story = {
  args: {
    size: 'm',
    children: (
      <>
        <RiAddLine className="mr-1 h-4 w-4" />
        Add Item
      </>
    ),
  },
}

// Size variations
export const ExtraSmall: Story = {
  args: {
    size: 'xs',
    children: <RiEditLine className="h-3 w-3" />,
  },
}

// Note: the 'S' slot reuses the 'xs' size with a slightly larger icon,
// matching the SizeComparison story below.
export const Small: Story = {
  args: {
    size: 'xs',
    children: <RiEditLine className="h-3.5 w-3.5" />,
  },
}

export const Medium: Story = {
  args: {
    size: 'm',
    children: <RiEditLine className="h-4 w-4" />,
  },
}

export const Large: Story = {
  args: {
    size: 'l',
    children: <RiEditLine className="h-5 w-5" />,
  },
}

export const ExtraLarge: Story = {
  args: {
    size: 'xl',
    children: <RiEditLine className="h-6 w-6" />,
  },
}

// State variations
export const ActiveState: Story = {
  args: {
    size: 'm',
    state: ActionButtonState.Active,
    children: <RiEditLine className="h-4 w-4" />,
  },
}

export const DisabledState: Story = {
  args: {
    size: 'm',
    state: ActionButtonState.Disabled,
    children: <RiEditLine className="h-4 w-4" />,
  },
}

export const DestructiveState: Story = {
  args: {
    size: 'm',
    state: ActionButtonState.Destructive,
    children: <RiDeleteBinLine className="h-4 w-4" />,
  },
}

export const HoverState: Story = {
  args: {
    size: 'm',
    state: ActionButtonState.Hover,
    children: <RiEditLine className="h-4 w-4" />,
  },
}

// Real-world examples
export const ToolbarActions: Story = {
  render: () => (
    <div className="flex items-center gap-1 rounded-lg bg-background-section-burn p-2">
      <ActionButton size="m">
        <RiEditLine className="h-4 w-4" />
      </ActionButton>
      <ActionButton size="m">
        <RiShareLine className="h-4 w-4" />
      </ActionButton>
      <ActionButton size="m">
        <RiSaveLine className="h-4 w-4" />
      </ActionButton>
      <div className="mx-1 h-4 w-px bg-divider-regular" />
      <ActionButton size="m" state={ActionButtonState.Destructive}>
        <RiDeleteBinLine className="h-4 w-4" />
      </ActionButton>
    </div>
  ),
}

export const InlineActions: Story = {
  render: () => (
    <div className="flex items-center gap-2">
      <span className="text-text-secondary">Item name</span>
      <ActionButton size="xs">
        <RiEditLine className="h-3.5 w-3.5" />
      </ActionButton>
      <ActionButton size="xs">
        <RiMore2Fill className="h-3.5 w-3.5" />
      </ActionButton>
    </div>
  ),
}

export const SizeComparison: Story = {
  render: () => (
    <div className="flex items-center gap-4">
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="xs">
          <RiEditLine className="h-3 w-3" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">XS</span>
      </div>
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="xs">
          <RiEditLine className="h-3.5 w-3.5" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">S</span>
      </div>
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="m">
          <RiEditLine className="h-4 w-4" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">M</span>
      </div>
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="l">
          <RiEditLine className="h-5 w-5" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">L</span>
      </div>
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="xl">
          <RiEditLine className="h-6 w-6" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">XL</span>
      </div>
    </div>
  ),
}

export const StateComparison: Story = {
  render: () => (
    <div className="flex items-center gap-4">
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="m" state={ActionButtonState.Default}>
          <RiEditLine className="h-4 w-4" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">Default</span>
      </div>
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="m" state={ActionButtonState.Active}>
          <RiEditLine className="h-4 w-4" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">Active</span>
      </div>
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="m" state={ActionButtonState.Hover}>
          <RiEditLine className="h-4 w-4" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">Hover</span>
      </div>
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="m" state={ActionButtonState.Disabled}>
          <RiEditLine className="h-4 w-4" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">Disabled</span>
      </div>
      <div className="flex flex-col items-center gap-2">
        <ActionButton size="m" state={ActionButtonState.Destructive}>
          <RiDeleteBinLine className="h-4 w-4" />
        </ActionButton>
        <span className="text-xs text-text-tertiary">Destructive</span>
      </div>
    </div>
  ),
}

// Interactive playground
export const Playground: Story = {
  args: {
    size: 'm',
    state: ActionButtonState.Default,
    children: <RiEditLine className="h-4 w-4" />,
  },
}
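A minimal consumption sketch for ActionButton outside Storybook, combining the size and state props shown in the stories above. The import paths, the named ActionButtonState export, and the pass-through of a native onClick handler are assumptions for illustration, not confirmed component API:

import { RiDeleteBinLine } from '@remixicon/react'
import ActionButton, { ActionButtonState } from '.'

const DeleteRowButton = ({ busy, onDelete }: { busy: boolean; onDelete: () => void }) => (
  <ActionButton
    size="m"
    // Destructive styling signals a dangerous action; Disabled blocks it while busy.
    state={busy ? ActionButtonState.Disabled : ActionButtonState.Destructive}
    onClick={onDelete}
  >
    <RiDeleteBinLine className="h-4 w-4" />
  </ActionButton>
)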
@@ -0,0 +1,204 @@
import type { Meta, StoryObj } from '@storybook/nextjs'
import { useState } from 'react'
import AutoHeightTextarea from '.'

const meta = {
  title: 'Base/AutoHeightTextarea',
  component: AutoHeightTextarea,
  parameters: {
    layout: 'centered',
    docs: {
      description: {
        component: 'Auto-resizing textarea component that expands and contracts based on content, with configurable min/max height constraints.',
      },
    },
  },
  tags: ['autodocs'],
  argTypes: {
    placeholder: {
      control: 'text',
      description: 'Placeholder text',
    },
    value: {
      control: 'text',
      description: 'Textarea value',
    },
    minHeight: {
      control: 'number',
      description: 'Minimum height in pixels',
    },
    maxHeight: {
      control: 'number',
      description: 'Maximum height in pixels',
    },
    autoFocus: {
      control: 'boolean',
      description: 'Auto focus on mount',
    },
    className: {
      control: 'text',
      description: 'Additional CSS classes',
    },
    wrapperClassName: {
      control: 'text',
      description: 'Wrapper CSS classes',
    },
  },
} satisfies Meta<typeof AutoHeightTextarea>

export default meta
type Story = StoryObj<typeof meta>

// Interactive demo wrapper
const AutoHeightTextareaDemo = (args: any) => {
  const [value, setValue] = useState(args.value || '')

  return (
    <div style={{ width: '500px' }}>
      <AutoHeightTextarea
        {...args}
        value={value}
        onChange={(e) => {
          setValue(e.target.value)
          console.log('Text changed:', e.target.value)
        }}
      />
    </div>
  )
}

// Default state
export const Default: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Type something...',
    value: '',
    minHeight: 36,
    maxHeight: 96,
    className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
  },
}

// With initial value
export const WithInitialValue: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Type something...',
    value: 'This is a pre-filled textarea with some initial content.',
    minHeight: 36,
    maxHeight: 96,
    className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
  },
}

// With multiline content
export const MultilineContent: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Type something...',
    value: 'Line 1\nLine 2\nLine 3\nLine 4\nThis textarea automatically expands to fit the content.',
    minHeight: 36,
    maxHeight: 96,
    className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
  },
}

// Custom min height
export const CustomMinHeight: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Taller minimum height...',
    value: '',
    minHeight: 100,
    maxHeight: 200,
    className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
  },
}

// Small max height (scrollable)
export const SmallMaxHeight: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Type multiple lines...',
    value: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6\nThis will become scrollable when it exceeds max height.',
    minHeight: 36,
    maxHeight: 80,
    className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
  },
}

// Auto focus enabled
export const AutoFocus: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'This textarea auto-focuses on mount',
    value: '',
    minHeight: 36,
    maxHeight: 96,
    autoFocus: true,
    className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
  },
}

// With custom styling
export const CustomStyling: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Custom styled textarea...',
    value: '',
    minHeight: 50,
    maxHeight: 150,
    className: 'w-full p-3 bg-gray-50 border-2 border-blue-400 rounded-xl text-lg focus:outline-none focus:bg-white focus:border-blue-600',
    wrapperClassName: 'shadow-lg',
  },
}

// Long content example
export const LongContent: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Type something...',
    value: 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.\n\nUt enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.\n\nDuis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.\n\nExcepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.',
    minHeight: 36,
    maxHeight: 200,
    className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
  },
}

// Real-world example - Chat input
export const ChatInput: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Type your message...',
    value: '',
    minHeight: 40,
    maxHeight: 120,
    className: 'w-full px-4 py-2 bg-gray-100 border border-gray-300 rounded-2xl text-sm focus:outline-none focus:bg-white focus:ring-2 focus:ring-blue-500',
  },
}

// Real-world example - Comment box
export const CommentBox: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Write a comment...',
    value: '',
    minHeight: 60,
    maxHeight: 200,
    className: 'w-full p-3 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-indigo-500',
  },
}

// Interactive playground
export const Playground: Story = {
  render: args => <AutoHeightTextareaDemo {...args} />,
  args: {
    placeholder: 'Type something...',
    value: '',
    minHeight: 36,
    maxHeight: 96,
    autoFocus: false,
    className: 'w-full p-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500',
    wrapperClassName: '',
  },
}
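A minimal controlled-usage sketch, following the demo wrapper above: the component is driven by React state via value/onChange, with min/max height constraints in pixels. The import path and the NoteEditor wrapper are assumptions for illustration:

import { useState } from 'react'
import AutoHeightTextarea from '.'

const NoteEditor = () => {
  const [note, setNote] = useState('')
  return (
    <AutoHeightTextarea
      value={note}
      onChange={e => setNote(e.target.value)}
      placeholder="Write a note..."
      minHeight={36}   // starts one line tall
      maxHeight={200}  // grows with content, then scrolls
    />
  )
}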
@@ -31,7 +31,7 @@ const AutoHeightTextarea = (
    onKeyDown,
    onKeyUp,
  }: IProps & {
    ref: React.RefObject<unknown>;
    ref?: React.RefObject<HTMLTextAreaElement>;
  },
) => {
  // eslint-disable-next-line react-hooks/rules-of-hooks
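With the ref prop retyped from React.RefObject<unknown> to an optional React.RefObject<HTMLTextAreaElement>, callers can attach a properly typed ref and use textarea DOM methods without casting. A brief illustrative sketch (the surrounding component is assumed):

// const textareaRef = useRef<HTMLTextAreaElement>(null)
// <AutoHeightTextarea ref={textareaRef} value={value} onChange={handleChange} />
// textareaRef.current?.focus()  // typed access, no cast needed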
@@ -0,0 +1,191 @@
import type { Meta, StoryObj } from '@storybook/nextjs'
import { useState } from 'react'
import BlockInput from '.'

const meta = {
  title: 'Base/BlockInput',
  component: BlockInput,
  parameters: {
    layout: 'centered',
    docs: {
      description: {
        component: 'Block input component with variable highlighting. Supports {{variable}} syntax with validation and visual highlighting of variable names.',
      },
    },
  },
  tags: ['autodocs'],
  argTypes: {
    value: {
      control: 'text',
      description: 'Input value (supports {{variable}} syntax)',
    },
    className: {
      control: 'text',
      description: 'Wrapper CSS classes',
    },
    highLightClassName: {
      control: 'text',
      description: 'CSS class for highlighted variables (default: text-blue-500)',
    },
    readonly: {
      control: 'boolean',
      description: 'Read-only mode',
    },
  },
} satisfies Meta<typeof BlockInput>

export default meta
type Story = StoryObj<typeof meta>

// Interactive demo wrapper
const BlockInputDemo = (args: any) => {
  const [value, setValue] = useState(args.value || '')
  const [keys, setKeys] = useState<string[]>([])

  return (
    <div style={{ width: '600px' }}>
      <BlockInput
        {...args}
        value={value}
        onConfirm={(newValue, extractedKeys) => {
          setValue(newValue)
          setKeys(extractedKeys)
          console.log('Value confirmed:', newValue)
          console.log('Extracted keys:', extractedKeys)
        }}
      />
      {keys.length > 0 && (
        <div className="mt-4 rounded-lg bg-blue-50 p-3">
          <div className="mb-2 text-sm font-medium text-gray-700">Detected Variables:</div>
          <div className="flex flex-wrap gap-2">
            {keys.map(key => (
              <span key={key} className="rounded bg-blue-500 px-2 py-1 text-xs text-white">
                {key}
              </span>
            ))}
          </div>
        </div>
      )}
    </div>
  )
}

// Default state
export const Default: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: '',
    readonly: false,
  },
}

// With single variable
export const SingleVariable: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'Hello {{name}}, welcome to the application!',
    readonly: false,
  },
}

// With multiple variables
export const MultipleVariables: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'Dear {{user_name}},\n\nYour order {{order_id}} has been shipped to {{address}}.\n\nThank you for shopping with us!',
    readonly: false,
  },
}

// Complex template
export const ComplexTemplate: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'Hi {{customer_name}},\n\nYour {{product_type}} subscription will renew on {{renewal_date}} for {{amount}}.\n\nYour payment method ending in {{card_last_4}} will be charged.\n\nQuestions? Contact us at {{support_email}}.',
    readonly: false,
  },
}

// Read-only mode
export const ReadOnlyMode: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'This is a read-only template with {{variable1}} and {{variable2}}.\n\nYou cannot edit this content.',
    readonly: true,
  },
}

// Empty state
export const EmptyState: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: '',
    readonly: false,
  },
}

// Long content
export const LongContent: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'Dear {{recipient_name}},\n\nWe are writing to inform you about the upcoming changes to your {{service_name}} account.\n\nEffective {{effective_date}}, your plan will include:\n\n1. Access to {{feature_1}}\n2. {{feature_2}} with unlimited usage\n3. Priority support via {{support_channel}}\n4. Monthly reports sent to {{email_address}}\n\nYour new monthly rate will be {{new_price}}, compared to your current rate of {{old_price}}.\n\nIf you have any questions, please contact our team at {{contact_info}}.\n\nBest regards,\n{{company_name}} Team',
    readonly: false,
  },
}

// Variables with underscores
export const VariablesWithUnderscores: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'User {{user_id}} from {{user_country}} has {{total_orders}} orders with status {{order_status}}.',
    readonly: false,
  },
}

// Adjacent variables
export const AdjacentVariables: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'File: {{file_name}}.{{file_extension}} ({{file_size}}{{size_unit}})',
    readonly: false,
  },
}

// Real-world example - Email template
export const EmailTemplate: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'Subject: Your {{service_name}} account has been created\n\nHi {{first_name}},\n\nWelcome to {{company_name}}! Your account is now active.\n\nUsername: {{username}}\nEmail: {{email}}\n\nGet started at {{app_url}}\n\nThanks,\nThe {{company_name}} Team',
    readonly: false,
  },
}

// Real-world example - Notification template
export const NotificationTemplate: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: '🔔 {{user_name}} mentioned you in {{channel_name}}\n\n"{{message_preview}}"\n\nReply now: {{message_url}}',
    readonly: false,
  },
}

// Custom styling
export const CustomStyling: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'This template uses {{custom_variable}} with custom styling.',
    readonly: false,
    className: 'bg-gray-50 border-2 border-blue-200',
  },
}

// Interactive playground
export const Playground: Story = {
  render: args => <BlockInputDemo {...args} />,
  args: {
    value: 'Try editing this text and adding variables like {{example}}',
    readonly: false,
    className: '',
    highLightClassName: '',
  },
}
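A sketch of the kind of key extraction that BlockInput's onConfirm callback reports. The exact pattern the component uses is not shown in this diff; the regex below is an assumption that matches the {{variable}} examples in the stories above (alphanumerics and underscores between double braces):

const extractKeys = (template: string): string[] => {
  // Find every {{variable}} occurrence, e.g. '{{user_name}}'.
  const matches = template.match(/\{\{([a-zA-Z0-9_]+)\}\}/g) || []
  // Strip the surrounding braces and de-duplicate.
  return [...new Set(matches.map(m => m.slice(2, -2)))]
}

// extractKeys('Hello {{name}}, order {{order_id}}') -> ['name', 'order_id']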