mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 18:27:19 +08:00
Merge branch 'main' into jzh
This commit is contained in:
commit
9dd73b4d47
@ -39,6 +39,58 @@ class AbstractVectorFactory(ABC):
|
||||
return index_struct_dict
|
||||
|
||||
|
||||
class _LazyEmbeddings(Embeddings):
|
||||
"""Lazy proxy that defers materializing the real embedding model.
|
||||
|
||||
Constructing the real embeddings (via ``ModelManager.get_model_instance``)
|
||||
transitively calls ``FeatureService.get_features`` → ``BillingService``
|
||||
HTTP GETs (see ``provider_manager.py``). Cleanup paths
|
||||
(``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings
|
||||
at all, so deferring this until an ``embed_*`` method is actually invoked
|
||||
keeps cleanup tasks resilient to transient billing-API failures and avoids
|
||||
leaving stranded ``document_segments`` / ``child_chunks`` whenever billing
|
||||
hiccups.
|
||||
|
||||
Existing callers that perform create / search operations are unaffected:
|
||||
the first ``embed_*`` call materializes the underlying model and the
|
||||
behavior is identical from that point on.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self._dataset = dataset
|
||||
self._real: Embeddings | None = None
|
||||
|
||||
def _ensure(self) -> Embeddings:
|
||||
if self._real is None:
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model,
|
||||
)
|
||||
self._real = CacheEmbedding(embedding_model)
|
||||
return self._real
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self._ensure().embed_documents(texts)
|
||||
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
|
||||
return self._ensure().embed_multimodal_documents(multimodel_documents)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._ensure().embed_query(text)
|
||||
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
return self._ensure().embed_multimodal_query(multimodel_document)
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return await self._ensure().aembed_documents(texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
return await self._ensure().aembed_query(text)
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list | None = None):
|
||||
if attributes is None:
|
||||
@ -60,7 +112,11 @@ class Vector:
|
||||
"original_chunk_id",
|
||||
]
|
||||
self._dataset = dataset
|
||||
self._embeddings = self._get_embeddings()
|
||||
# Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists)
|
||||
# never transitively trigger billing API calls during ``Vector(dataset)``
|
||||
# construction. The real embedding model is materialized only when an
|
||||
# ``embed_*`` method is actually invoked (i.e. create / search paths).
|
||||
self._embeddings: Embeddings = _LazyEmbeddings(dataset)
|
||||
self._attributes = attributes
|
||||
self._vector_processor = self._init_vector()
|
||||
|
||||
|
||||
@ -2182,7 +2182,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
|
||||
return result
|
||||
|
||||
|
||||
class UploadFile(Base):
|
||||
class UploadFile(TypeBase):
|
||||
__tablename__ = "upload_files"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
||||
@ -2190,9 +2190,12 @@ class UploadFile(Base):
|
||||
)
|
||||
|
||||
# NOTE: The `id` field is generated within the application to minimize extra roundtrips
|
||||
# (especially when generating `source_url`).
|
||||
# The `server_default` serves as a fallback mechanism.
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
# (especially when generating `source_url`) and keep model metadata portable across databases.
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
init=False,
|
||||
default_factory=lambda: str(uuid4()),
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
@ -2200,16 +2203,6 @@ class UploadFile(Base):
|
||||
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
extension: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
|
||||
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
|
||||
# Its value is derived from the `CreatorUserRole` enumeration.
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(
|
||||
EnumText(CreatorUserRole, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'account'"),
|
||||
default=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
|
||||
# The `created_by` field stores the ID of the entity that created this upload file.
|
||||
#
|
||||
# If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`.
|
||||
@ -2228,10 +2221,18 @@ class UploadFile(Base):
|
||||
# `used` may indicate whether the file has been utilized by another service.
|
||||
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
|
||||
# Its value is derived from the `CreatorUserRole` enumeration.
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(
|
||||
EnumText(CreatorUserRole, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'account'"),
|
||||
default=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
# `used_by` may indicate the ID of the user who utilized this file.
|
||||
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
|
||||
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True, default=None)
|
||||
hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
source_url: Mapped[str] = mapped_column(LongText, default="")
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -50,7 +50,7 @@ from libs.uuid_utils import uuidv7
|
||||
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import AppMode, UploadFile
|
||||
from .model import AppMode
|
||||
|
||||
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
||||
@ -63,6 +63,10 @@ from .account import Account
|
||||
from .base import Base, DefaultFieldsDCMixin, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
|
||||
|
||||
# UploadFile uses TypeBase while workflow execution offload models use Base, so relationships
|
||||
# must target the class object directly instead of relying on string lookup across registries.
|
||||
from .model import UploadFile
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
from .utils.file_input_compat import (
|
||||
build_file_from_mapping_without_lookup,
|
||||
@ -1096,8 +1100,6 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
@staticmethod
|
||||
def _load_full_content(session: orm.Session, file_id: str, storage: Storage):
|
||||
from .model import UploadFile
|
||||
|
||||
stmt = sa.select(UploadFile).where(UploadFile.id == file_id)
|
||||
file = session.scalars(stmt).first()
|
||||
assert file is not None, f"UploadFile with id {file_id} should exist but not"
|
||||
@ -1191,10 +1193,11 @@ class WorkflowNodeExecutionOffload(Base):
|
||||
)
|
||||
|
||||
file: Mapped[Optional["UploadFile"]] = orm.relationship(
|
||||
UploadFile,
|
||||
foreign_keys=[file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id",
|
||||
primaryjoin=lambda: orm.foreign(WorkflowNodeExecutionOffload.file_id) == UploadFile.id,
|
||||
)
|
||||
|
||||
|
||||
@ -1968,10 +1971,11 @@ class WorkflowDraftVariableFile(Base):
|
||||
|
||||
# Relationship to UploadFile
|
||||
upload_file: Mapped["UploadFile"] = orm.relationship(
|
||||
UploadFile,
|
||||
foreign_keys=[upload_file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id",
|
||||
primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -61,13 +61,31 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
|
||||
# check segment is exist
|
||||
if index_node_ids:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
# Wrap vector / keyword index cleanup in try/except so that a transient
|
||||
# failure here (e.g. billing API hiccup propagated via FeatureService when
|
||||
# ModelManager is initialized inside ``Vector(dataset)``) does not abort
|
||||
# the entire task and leave document_segments / child_chunks / image_files
|
||||
# / metadata bindings stranded in PG. Mirrors the pattern already used in
|
||||
# ``clean_dataset_task`` so the document row's hard delete (already
|
||||
# committed by the caller) does not produce orphan PG rows just because
|
||||
# the vector backend or one of its transitive dependencies was unhappy.
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to clean vector / keyword index in clean_document_task, "
|
||||
"document_id=%s, dataset_id=%s, index_node_ids_count=%d. "
|
||||
"Continuing with PG / storage cleanup; vector orphans can be reaped later.",
|
||||
document_id,
|
||||
dataset_id,
|
||||
len(index_node_ids),
|
||||
)
|
||||
|
||||
total_image_files = []
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
|
||||
@ -40,12 +40,29 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
total_index_node_ids.extend([segment.index_node_id for segment in segments])
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
# Wrap vector / keyword index cleanup in try/except so that a transient
|
||||
# failure here (e.g. billing API hiccup propagated via FeatureService when
|
||||
# ``ModelManager`` is initialized inside ``Vector(dataset)``) does not abort
|
||||
# the task and leave the already-deleted documents' segments stranded in PG.
|
||||
# The Document rows are hard-deleted in the previous session block, so any
|
||||
# exception escaping this task would produce orphans that no later request
|
||||
# can reference back. Mirrors the pattern in ``clean_dataset_task``.
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to clean vector / keyword index in clean_notion_document_task, "
|
||||
"dataset_id=%s, document_ids=%s, index_node_ids_count=%d. "
|
||||
"Continuing with segment deletion; vector orphans can be reaped later.",
|
||||
dataset_id,
|
||||
document_ids,
|
||||
len(total_index_node_ids),
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||
|
||||
@ -11,6 +11,7 @@ from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -171,14 +172,13 @@ def process_tenant_plugin_autoupgrade_check_task(
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
_ = manager.upgrade_plugin(
|
||||
# Use the service that downloads and uploads the package to the daemon
|
||||
# first; calling manager.upgrade_plugin directly skips that step and the
|
||||
# daemon fails because the package never reaches its local bucket.
|
||||
_ = PluginService.upgrade_plugin_with_marketplace(
|
||||
tenant_id,
|
||||
original_unique_identifier,
|
||||
new_unique_identifier,
|
||||
PluginInstallationSource.Marketplace,
|
||||
{
|
||||
"plugin_unique_identifier": new_unique_identifier,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))
|
||||
|
||||
@ -602,14 +602,25 @@ class TestCleanNotionDocumentTask:
|
||||
# Note: This test successfully verifies database operations.
|
||||
# IndexProcessor verification would require more sophisticated mocking.
|
||||
|
||||
def test_clean_notion_document_task_database_transaction_rollback(
|
||||
def test_clean_notion_document_task_continues_when_index_processor_fails(
|
||||
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup task behavior when database operations fail.
|
||||
Index processor failure (e.g. transient billing API error propagated via
|
||||
``FeatureService`` when ``Vector(dataset)`` lazily resolves the embedding
|
||||
model) must NOT abort the cleanup task. The Document rows have already
|
||||
been hard-deleted in the first session block before vector cleanup runs,
|
||||
so any uncaught exception escaping the task would strand
|
||||
``DocumentSegment`` rows in PG with no parent ``Document``.
|
||||
|
||||
This test verifies that the task properly handles database errors
|
||||
and maintains data consistency.
|
||||
Contract: the task swallows the index_processor exception, logs it, and
|
||||
proceeds to delete the segments — leaving PG consistent. (Vector orphans,
|
||||
if any, can be reaped later by an offline scanner.)
|
||||
|
||||
Regression guard for the production incident where ``clean_document_task``
|
||||
/ ``clean_notion_document_task`` failed with
|
||||
``ValueError("Unable to retrieve billing information...")`` and left
|
||||
tens of thousands of orphan segments per affected tenant.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
@ -672,17 +683,28 @@ class TestCleanNotionDocumentTask:
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock index processor to raise an exception
|
||||
# Simulate the production failure mode: index_processor.clean() raises a
|
||||
# ValueError mirroring ``BillingService._send_request`` returning non-200.
|
||||
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_index_processor.clean.side_effect = Exception("Index processor error")
|
||||
mock_index_processor.clean.side_effect = ValueError(
|
||||
"Unable to retrieve billing information. Please try again later or contact support."
|
||||
)
|
||||
|
||||
# Execute cleanup task - current implementation propagates the exception
|
||||
with pytest.raises(Exception, match="Index processor error"):
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
# Execute cleanup task — must NOT raise even though clean() raises.
|
||||
# Before the safety-net wrapper this would have re-raised the ValueError,
|
||||
# aborting the task and leaving DocumentSegment stranded in PG.
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
|
||||
# Note: This test demonstrates the task's error handling capability.
|
||||
# Even with external service errors, the database operations complete successfully.
|
||||
# In a production environment, proper error handling would determine transaction rollback behavior.
|
||||
# Vector cleanup was attempted exactly once.
|
||||
mock_index_processor.clean.assert_called_once()
|
||||
|
||||
# The crucial assertion: despite the index processor failure, the
|
||||
# final session block (line 51-52, ``DELETE FROM document_segments``)
|
||||
# still ran and committed. This is what the wrapper buys us — without
|
||||
# it the production incident left tens of thousands of orphan segments
|
||||
# per affected tenant. Aligns with the assertion shape used by the
|
||||
# happy-path test (``test_clean_notion_document_task_success``).
|
||||
assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0
|
||||
|
||||
def test_clean_notion_document_task_with_large_number_of_documents(
|
||||
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
|
||||
|
||||
@ -146,10 +146,7 @@ def test_get_vector_factory_entry_point_overrides_builtin(vector_factory_module,
|
||||
def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
|
||||
dataset = SimpleNamespace(id="dataset-1")
|
||||
|
||||
with (
|
||||
patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"),
|
||||
patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"),
|
||||
):
|
||||
with patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"):
|
||||
default_vector = vector_factory_module.Vector(dataset)
|
||||
custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])
|
||||
|
||||
@ -166,10 +163,57 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
|
||||
"original_chunk_id",
|
||||
]
|
||||
assert custom_vector._attributes == ["doc_id"]
|
||||
assert default_vector._embeddings == "embeddings"
|
||||
# ``_embeddings`` is now a lazy proxy that defers materializing the real
|
||||
# embedding model until ``embed_*`` is invoked, so cleanup paths never
|
||||
# trigger billing/feature-service calls during ``Vector(dataset)``
|
||||
# construction. See ``_LazyEmbeddings``.
|
||||
assert isinstance(default_vector._embeddings, vector_factory_module._LazyEmbeddings)
|
||||
assert default_vector._vector_processor == "processor"
|
||||
|
||||
|
||||
def test_lazy_embeddings_defer_real_load_until_first_embed_call(vector_factory_module, monkeypatch):
|
||||
"""``Vector(dataset)`` must not transitively call ``ModelManager`` during
|
||||
construction. The real embedding model should only be materialized on the
|
||||
first ``embed_*`` call (i.e. create / search paths) so cleanup paths
|
||||
(``delete_by_ids`` / ``delete``) remain resilient to billing-API failures.
|
||||
"""
|
||||
for_tenant_mock = MagicMock(side_effect=AssertionError("ModelManager.for_tenant must not be called eagerly"))
|
||||
monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock)
|
||||
|
||||
dataset = SimpleNamespace(
|
||||
tenant_id="tenant-1",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-3-small",
|
||||
)
|
||||
|
||||
proxy = vector_factory_module._LazyEmbeddings(dataset)
|
||||
|
||||
# Construction alone does not trigger ModelManager / FeatureService / BillingService.
|
||||
for_tenant_mock.assert_not_called()
|
||||
|
||||
# Exercising an embed_* method materializes the real model exactly once.
|
||||
inner_model = MagicMock()
|
||||
inner_model.embed_documents.return_value = [[0.1, 0.2]]
|
||||
cached_embedding_mock = MagicMock(return_value=inner_model)
|
||||
real_for_tenant = MagicMock()
|
||||
real_for_tenant.get_model_instance.return_value = "embedding-model-instance"
|
||||
monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", MagicMock(return_value=real_for_tenant))
|
||||
monkeypatch.setattr(vector_factory_module, "CacheEmbedding", cached_embedding_mock)
|
||||
|
||||
result = proxy.embed_documents(["hello"])
|
||||
|
||||
assert result == [[0.1, 0.2]]
|
||||
cached_embedding_mock.assert_called_once_with("embedding-model-instance")
|
||||
inner_model.embed_documents.assert_called_once_with(["hello"])
|
||||
|
||||
# Subsequent calls reuse the materialized model (no re-resolution).
|
||||
inner_model.embed_documents.reset_mock()
|
||||
cached_embedding_mock.reset_mock()
|
||||
proxy.embed_documents(["world"])
|
||||
cached_embedding_mock.assert_not_called()
|
||||
inner_model.embed_documents.assert_called_once_with(["world"])
|
||||
|
||||
|
||||
def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch):
|
||||
calls = {"vector_type": None, "init_args": None}
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ class TestWorkflowModelValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=graph,
|
||||
features=features,
|
||||
@ -58,7 +58,7 @@ class TestWorkflowModelValidation:
|
||||
# Assert
|
||||
assert workflow.tenant_id == tenant_id
|
||||
assert workflow.app_id == app_id
|
||||
assert workflow.type == WorkflowType.WORKFLOW.value
|
||||
assert workflow.type == WorkflowType.WORKFLOW
|
||||
assert workflow.version == "draft"
|
||||
assert workflow.graph == graph
|
||||
assert workflow.created_by == created_by
|
||||
@ -68,7 +68,7 @@ class TestWorkflowModelValidation:
|
||||
def test_workflow_type_enum_values(self):
|
||||
"""Test WorkflowType enum values."""
|
||||
# Assert
|
||||
assert WorkflowType.WORKFLOW.value == "workflow"
|
||||
assert WorkflowType.WORKFLOW == "workflow"
|
||||
assert WorkflowType.CHAT.value == "chat"
|
||||
assert WorkflowType.RAG_PIPELINE.value == "rag-pipeline"
|
||||
|
||||
@ -89,7 +89,7 @@ class TestWorkflowModelValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_data),
|
||||
features="{}",
|
||||
@ -114,7 +114,7 @@ class TestWorkflowModelValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph="{}",
|
||||
features=json.dumps(features_data),
|
||||
@ -138,7 +138,7 @@ class TestWorkflowModelValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="v1.0",
|
||||
graph="{}",
|
||||
features="{}",
|
||||
@ -176,11 +176,11 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
@ -188,9 +188,9 @@ class TestWorkflowRunStateTransitions:
|
||||
assert workflow_run.tenant_id == tenant_id
|
||||
assert workflow_run.app_id == app_id
|
||||
assert workflow_run.workflow_id == workflow_id
|
||||
assert workflow_run.type == WorkflowType.WORKFLOW.value
|
||||
assert workflow_run.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING.value
|
||||
assert workflow_run.type == WorkflowType.WORKFLOW
|
||||
assert workflow_run.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
assert workflow_run.created_by == created_by
|
||||
|
||||
def test_workflow_run_state_transition_running_to_succeeded(self):
|
||||
@ -200,21 +200,21 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.END_USER.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.SUCCEEDED.value
|
||||
workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
workflow_run.finished_at = datetime.now(UTC)
|
||||
workflow_run.elapsed_time = 2.5
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert workflow_run.finished_at is not None
|
||||
assert workflow_run.elapsed_time == 2.5
|
||||
|
||||
@ -225,21 +225,21 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.FAILED.value
|
||||
workflow_run.status = WorkflowExecutionStatus.FAILED
|
||||
workflow_run.error = "Node execution failed: Invalid input"
|
||||
workflow_run.finished_at = datetime.now(UTC)
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.FAILED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.FAILED
|
||||
assert workflow_run.error == "Node execution failed: Invalid input"
|
||||
assert workflow_run.finished_at is not None
|
||||
|
||||
@ -250,20 +250,20 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.STOPPED.value
|
||||
workflow_run.status = WorkflowExecutionStatus.STOPPED
|
||||
workflow_run.finished_at = datetime.now(UTC)
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.STOPPED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.STOPPED
|
||||
assert workflow_run.finished_at is not None
|
||||
|
||||
def test_workflow_run_state_transition_running_to_paused(self):
|
||||
@ -273,19 +273,19 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.END_USER.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED.value
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
assert workflow_run.finished_at is None # Not finished when paused
|
||||
|
||||
def test_workflow_run_state_transition_paused_to_running(self):
|
||||
@ -295,19 +295,19 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.PAUSED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING.value
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
|
||||
def test_workflow_run_with_partial_succeeded_status(self):
|
||||
"""Test workflow run with partial-succeeded status."""
|
||||
@ -316,17 +316,17 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
exceptions_count=2,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
assert workflow_run.exceptions_count == 2
|
||||
|
||||
def test_workflow_run_with_inputs_and_outputs(self):
|
||||
@ -340,11 +340,11 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.END_USER.value,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=str(uuid4()),
|
||||
inputs=json.dumps(inputs),
|
||||
outputs=json.dumps(outputs),
|
||||
@ -362,11 +362,11 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
graph=json.dumps(graph),
|
||||
)
|
||||
@ -391,11 +391,11 @@ class TestWorkflowRunStateTransitions:
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=created_by,
|
||||
total_tokens=1500,
|
||||
total_steps=5,
|
||||
@ -410,7 +410,7 @@ class TestWorkflowRunStateTransitions:
|
||||
assert result["tenant_id"] == tenant_id
|
||||
assert result["app_id"] == app_id
|
||||
assert result["workflow_id"] == workflow_id
|
||||
assert result["status"] == WorkflowExecutionStatus.SUCCEEDED.value
|
||||
assert result["status"] == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert result["total_tokens"] == 1500
|
||||
assert result["total_steps"] == 5
|
||||
|
||||
@ -422,18 +422,18 @@ class TestWorkflowRunStateTransitions:
|
||||
"tenant_id": str(uuid4()),
|
||||
"app_id": str(uuid4()),
|
||||
"workflow_id": str(uuid4()),
|
||||
"type": WorkflowType.WORKFLOW.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"type": WorkflowType.WORKFLOW,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
"version": "v1.0",
|
||||
"graph": {"nodes": [], "edges": []},
|
||||
"inputs": {"query": "test"},
|
||||
"status": WorkflowExecutionStatus.SUCCEEDED.value,
|
||||
"status": WorkflowExecutionStatus.SUCCEEDED,
|
||||
"outputs": {"result": "success"},
|
||||
"error": None,
|
||||
"elapsed_time": 3.5,
|
||||
"total_tokens": 2000,
|
||||
"total_steps": 10,
|
||||
"created_by_role": CreatorUserRole.ACCOUNT.value,
|
||||
"created_by_role": CreatorUserRole.ACCOUNT,
|
||||
"created_by": str(uuid4()),
|
||||
"created_at": datetime.now(UTC),
|
||||
"finished_at": datetime.now(UTC),
|
||||
@ -446,7 +446,7 @@ class TestWorkflowRunStateTransitions:
|
||||
# Assert
|
||||
assert workflow_run.id == data["id"]
|
||||
assert workflow_run.workflow_id == data["workflow_id"]
|
||||
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED.value
|
||||
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert workflow_run.total_tokens == 2000
|
||||
|
||||
|
||||
@ -467,14 +467,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=workflow_run_id,
|
||||
index=1,
|
||||
node_id="start",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start Node",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
@ -498,15 +498,15 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=2,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
node_id=current_node_id,
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="LLM Node",
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
@ -528,8 +528,8 @@ class TestNodeExecutionRelationships:
|
||||
node_id="llm_test",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="Test LLM",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
@ -549,14 +549,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="llm_1",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="LLM Node",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
inputs=json.dumps(inputs),
|
||||
outputs=json.dumps(outputs),
|
||||
@ -575,24 +575,24 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="code_1",
|
||||
node_type=BuiltinNodeTypes.CODE,
|
||||
title="Code Node",
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act - transition to succeeded
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
node_execution.elapsed_time = 1.2
|
||||
node_execution.finished_at = datetime.now(UTC)
|
||||
|
||||
# Assert
|
||||
assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert node_execution.elapsed_time == 1.2
|
||||
assert node_execution.finished_at is not None
|
||||
|
||||
@ -606,20 +606,20 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=3,
|
||||
node_id="code_1",
|
||||
node_type=BuiltinNodeTypes.CODE,
|
||||
title="Code Node",
|
||||
status=WorkflowNodeExecutionStatus.FAILED.value,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert node_execution.status == WorkflowNodeExecutionStatus.FAILED.value
|
||||
assert node_execution.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert node_execution.error == error_message
|
||||
|
||||
def test_node_execution_with_metadata(self):
|
||||
@ -637,14 +637,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="llm_1",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="LLM Node",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
execution_metadata=json.dumps(metadata),
|
||||
)
|
||||
@ -660,14 +660,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="start",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
execution_metadata=None,
|
||||
)
|
||||
@ -696,14 +696,14 @@ class TestNodeExecutionRelationships:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id=f"{node_type}_1",
|
||||
node_type=node_type,
|
||||
title=title,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
@ -734,7 +734,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -761,7 +761,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -802,7 +802,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -835,11 +835,11 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
graph=json.dumps(original_graph),
|
||||
)
|
||||
@ -872,7 +872,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -912,7 +912,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=json.dumps(graph_config),
|
||||
features="{}",
|
||||
@ -933,7 +933,7 @@ class TestGraphConfigurationValidation:
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
version="draft",
|
||||
graph=None,
|
||||
features="{}",
|
||||
@ -956,11 +956,11 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
inputs=None,
|
||||
)
|
||||
@ -978,11 +978,11 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="v1.0",
|
||||
status=WorkflowExecutionStatus.RUNNING.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
outputs=None,
|
||||
)
|
||||
@ -1000,14 +1000,14 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="start",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
inputs=None,
|
||||
)
|
||||
@ -1025,14 +1025,14 @@ class TestGraphConfigurationValidation:
|
||||
tenant_id=str(uuid4()),
|
||||
app_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="start",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Start",
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
70
api/tests/unit_tests/oss/__mock/baidu_obs.py
Normal file
70
api/tests/unit_tests/oss/__mock/baidu_obs.py
Normal file
@ -0,0 +1,70 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from baidubce.services.bos.bos_client import BosClient
|
||||
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
get_example_bucket,
|
||||
get_example_data,
|
||||
get_example_filename,
|
||||
get_example_filepath,
|
||||
)
|
||||
|
||||
|
||||
class MockBaiduObsClass:
|
||||
def __init__(self, config=None):
|
||||
self.bucket_name = get_example_bucket()
|
||||
self.key = get_example_filename()
|
||||
self.content = get_example_data()
|
||||
self.filepath = get_example_filepath()
|
||||
|
||||
def put_object(self, bucket_name, key, data, content_length=None, content_md5=None, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
assert data == self.content
|
||||
assert content_length == len(self.content)
|
||||
expected_md5 = base64.standard_b64encode(hashlib.md5(self.content).digest())
|
||||
assert content_md5 == expected_md5
|
||||
|
||||
def get_object(self, bucket_name, key, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
return SimpleNamespace(data=BytesIO(self.content))
|
||||
|
||||
def get_object_to_file(self, bucket_name, key, file_name, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
assert file_name == self.filepath
|
||||
|
||||
def get_object_meta_data(self, bucket_name, key, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
return SimpleNamespace(status=200)
|
||||
|
||||
def delete_object(self, bucket_name, key, **kwargs):
|
||||
assert bucket_name == self.bucket_name
|
||||
assert key == self.key
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_baidu_obs_mock(monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(BosClient, "__init__", MockBaiduObsClass.__init__)
|
||||
monkeypatch.setattr(BosClient, "put_object", MockBaiduObsClass.put_object)
|
||||
monkeypatch.setattr(BosClient, "get_object", MockBaiduObsClass.get_object)
|
||||
monkeypatch.setattr(BosClient, "get_object_to_file", MockBaiduObsClass.get_object_to_file)
|
||||
monkeypatch.setattr(BosClient, "get_object_meta_data", MockBaiduObsClass.get_object_meta_data)
|
||||
monkeypatch.setattr(BosClient, "delete_object", MockBaiduObsClass.delete_object)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
1
api/tests/unit_tests/oss/baidu_obs/__init__.py
Normal file
1
api/tests/unit_tests/oss/baidu_obs/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
59
api/tests/unit_tests/oss/baidu_obs/test_baidu_obs.py
Normal file
59
api/tests/unit_tests/oss/baidu_obs/test_baidu_obs.py
Normal file
@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from baidubce.auth.bce_credentials import BceCredentials
|
||||
from baidubce.bce_client_configuration import BceClientConfiguration
|
||||
|
||||
from extensions.storage.baidu_obs_storage import BaiduObsStorage
|
||||
from tests.unit_tests.oss.__mock.baidu_obs import setup_baidu_obs_mock
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
BaseStorageTest,
|
||||
get_example_bucket,
|
||||
)
|
||||
|
||||
|
||||
class TestBaiduObs(BaseStorageTest):
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self, setup_baidu_obs_mock):
|
||||
"""Executed before each test method."""
|
||||
with (
|
||||
patch.object(BceCredentials, "__init__", return_value=None),
|
||||
patch.object(BceClientConfiguration, "__init__", return_value=None),
|
||||
):
|
||||
self.storage = BaiduObsStorage()
|
||||
self.storage.bucket_name = get_example_bucket()
|
||||
|
||||
|
||||
class TestBaiduObsConfiguration:
|
||||
def test_init_with_config(self):
|
||||
mock_dify_config = MagicMock()
|
||||
mock_dify_config.BAIDU_OBS_BUCKET_NAME = "test-bucket"
|
||||
mock_dify_config.BAIDU_OBS_ACCESS_KEY = "test-access-key"
|
||||
mock_dify_config.BAIDU_OBS_SECRET_KEY = "test-secret-key"
|
||||
mock_dify_config.BAIDU_OBS_ENDPOINT = "https://bj.bcebos.com"
|
||||
|
||||
mock_credentials = MagicMock(name="credentials")
|
||||
mock_config = MagicMock(name="config")
|
||||
mock_client = MagicMock(name="client")
|
||||
|
||||
with (
|
||||
patch("extensions.storage.baidu_obs_storage.dify_config", mock_dify_config),
|
||||
patch("extensions.storage.baidu_obs_storage.BceCredentials", return_value=mock_credentials) as credentials,
|
||||
patch(
|
||||
"extensions.storage.baidu_obs_storage.BceClientConfiguration", return_value=mock_config
|
||||
) as configuration,
|
||||
patch("extensions.storage.baidu_obs_storage.BosClient", return_value=mock_client) as client_cls,
|
||||
):
|
||||
storage = BaiduObsStorage()
|
||||
|
||||
assert storage.bucket_name == "test-bucket"
|
||||
assert storage.client == mock_client
|
||||
credentials.assert_called_once_with(
|
||||
access_key_id="test-access-key",
|
||||
secret_access_key="test-secret-key",
|
||||
)
|
||||
configuration.assert_called_once_with(
|
||||
credentials=mock_credentials,
|
||||
endpoint="https://bj.bcebos.com",
|
||||
)
|
||||
client_cls.assert_called_once_with(config=mock_config)
|
||||
291
api/tests/unit_tests/tasks/test_clean_document_task.py
Normal file
291
api/tests/unit_tests/tasks/test_clean_document_task.py
Normal file
@ -0,0 +1,291 @@
|
||||
"""
|
||||
Unit tests for clean_document_task.
|
||||
|
||||
Focuses on the resilience contract added by the billing-failure fix:
|
||||
``index_processor.clean()`` is wrapped in ``try/except`` so that a transient
|
||||
failure inside the vector / keyword cleanup (e.g. ``ValueError("Unable to
|
||||
retrieve billing information...")`` raised by ``BillingService._send_request``
|
||||
when ``Vector(dataset)`` transitively triggers ``FeatureService.get_features``)
|
||||
does not abort the entire task and leave PG with stranded ``DocumentSegment``
|
||||
/ ``ChildChunk`` / ``UploadFile`` / ``DatasetMetadataBinding`` rows.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tasks.clean_document_task import clean_document_task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Patch ``session_factory.create_session`` to return per-call mock sessions.
|
||||
|
||||
Each call to ``create_session()`` yields a fresh ``MagicMock`` session so we
|
||||
can assert ``execute()`` calls across the multiple short-lived transactions
|
||||
used by ``clean_document_task``.
|
||||
"""
|
||||
with patch("tasks.clean_document_task.session_factory", autospec=True) as mock_sf:
|
||||
sessions: list[MagicMock] = []
|
||||
|
||||
def _create_session():
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value.all.return_value = []
|
||||
session.scalar.return_value = None
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
sessions.append(session)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = _create_session
|
||||
yield mock_sf, sessions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
with patch("tasks.clean_document_task.storage", autospec=True) as mock:
|
||||
mock.delete.return_value = None
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_index_processor_factory():
|
||||
"""Mock ``IndexProcessorFactory`` so we can inject behavior into ``clean``."""
|
||||
with patch("tasks.clean_document_task.IndexProcessorFactory", autospec=True) as factory_cls:
|
||||
processor = MagicMock()
|
||||
processor.clean.return_value = None
|
||||
factory_instance = MagicMock()
|
||||
factory_instance.init_index_processor.return_value = processor
|
||||
factory_cls.return_value = factory_instance
|
||||
|
||||
yield {
|
||||
"factory_cls": factory_cls,
|
||||
"factory_instance": factory_instance,
|
||||
"processor": processor,
|
||||
}
|
||||
|
||||
|
||||
def _build_segment(segment_id: str, content: str = "segment content") -> MagicMock:
|
||||
seg = MagicMock()
|
||||
seg.id = segment_id
|
||||
seg.index_node_id = f"node-{segment_id}"
|
||||
seg.content = content
|
||||
return seg
|
||||
|
||||
|
||||
def _build_dataset(dataset_id: str, tenant_id: str) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = dataset_id
|
||||
ds.tenant_id = tenant_id
|
||||
return ds
|
||||
|
||||
|
||||
class TestVectorCleanupResilience:
|
||||
"""Vector / keyword cleanup must not abort the task on transient failure."""
|
||||
|
||||
def test_billing_failure_during_vector_cleanup_does_not_skip_pg_cleanup(
|
||||
self,
|
||||
document_id,
|
||||
dataset_id,
|
||||
tenant_id,
|
||||
mock_session_factory,
|
||||
mock_storage,
|
||||
mock_index_processor_factory,
|
||||
):
|
||||
"""Reproduces the production incident:
|
||||
|
||||
``Vector(dataset)`` transitively calls ``FeatureService.get_features``
|
||||
which calls ``BillingService._send_request("GET", ...)``. When billing
|
||||
returns non-200 it raises ``ValueError("Unable to retrieve billing
|
||||
information...")``. Before the fix this propagated out of
|
||||
``clean_document_task`` and left ``DocumentSegment`` / ``ChildChunk`` /
|
||||
``UploadFile`` / ``DatasetMetadataBinding`` rows orphaned because the
|
||||
already-deleted ``Document`` row had been hard-committed by the caller
|
||||
(``dataset_service.delete_document``) before ``.delay()`` was invoked.
|
||||
|
||||
Contract: a billing failure inside ``index_processor.clean()`` must be
|
||||
caught, logged, and the rest of the task must continue so PG ends up
|
||||
consistent with the deleted ``Document`` even if Qdrant retains
|
||||
orphan vectors that can be reaped later.
|
||||
"""
|
||||
mock_sf, sessions = mock_session_factory
|
||||
|
||||
# First create_session(): Step 1 (load segments + attachments).
|
||||
step1_session = MagicMock()
|
||||
step1_session.scalars.return_value.all.return_value = [
|
||||
_build_segment("seg-1"),
|
||||
_build_segment("seg-2"),
|
||||
]
|
||||
step1_session.execute.return_value.all.return_value = []
|
||||
step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
# Second create_session(): Step 2 (vector cleanup). Returns dataset.
|
||||
step2_session = MagicMock()
|
||||
step2_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
step2_session.scalars.return_value.all.return_value = []
|
||||
step2_session.execute.return_value.all.return_value = []
|
||||
# Subsequent sessions: Step 3+ (image / segment / file / metadata cleanup).
|
||||
# Default fixture returns empty results which is fine for these short txns.
|
||||
cm1, cm2 = MagicMock(), MagicMock()
|
||||
cm1.__enter__.return_value = step1_session
|
||||
cm1.__exit__.return_value = None
|
||||
cm2.__enter__.return_value = step2_session
|
||||
cm2.__exit__.return_value = None
|
||||
|
||||
def _default_cm():
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value.all.return_value = []
|
||||
session.scalar.return_value = None
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
sessions.append(session)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = [cm1, cm2] + [_default_cm() for _ in range(10)]
|
||||
|
||||
# Simulate the production failure: index_processor.clean() raises ValueError
|
||||
# mirroring BillingService._send_request when billing returns non-200.
|
||||
mock_index_processor_factory["processor"].clean.side_effect = ValueError(
|
||||
"Unable to retrieve billing information. Please try again later or contact support."
|
||||
)
|
||||
|
||||
# Act — must not raise out of the task even though clean() raises.
|
||||
clean_document_task(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
doc_form="paragraph",
|
||||
file_id=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# 1. Vector cleanup was attempted.
|
||||
mock_index_processor_factory["processor"].clean.assert_called_once()
|
||||
# 2. Despite the failure the task continued: at least one DocumentSegment
|
||||
# delete was issued. We use the count of session.execute calls across
|
||||
# later short transactions as a proxy for "Step 3+ executed".
|
||||
execute_calls = sum(s.execute.call_count for s in sessions)
|
||||
assert execute_calls > 0, (
|
||||
"Step 3+ DB cleanup did not run after vector cleanup failure; "
|
||||
"this regression would re-introduce the orphan-segment bug."
|
||||
)
|
||||
|
||||
def test_vector_cleanup_success_path_remains_unaffected(
|
||||
self,
|
||||
document_id,
|
||||
dataset_id,
|
||||
tenant_id,
|
||||
mock_session_factory,
|
||||
mock_storage,
|
||||
mock_index_processor_factory,
|
||||
):
|
||||
"""Backward-compat: the happy path must still call ``clean()`` exactly
|
||||
once with the expected arguments and complete without errors.
|
||||
"""
|
||||
mock_sf, sessions = mock_session_factory
|
||||
|
||||
step1_session = MagicMock()
|
||||
step1_session.scalars.return_value.all.return_value = [_build_segment("seg-1")]
|
||||
step1_session.execute.return_value.all.return_value = []
|
||||
step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
step2_session = MagicMock()
|
||||
step2_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
step2_session.scalars.return_value.all.return_value = []
|
||||
step2_session.execute.return_value.all.return_value = []
|
||||
cm1, cm2 = MagicMock(), MagicMock()
|
||||
cm1.__enter__.return_value = step1_session
|
||||
cm1.__exit__.return_value = None
|
||||
cm2.__enter__.return_value = step2_session
|
||||
cm2.__exit__.return_value = None
|
||||
|
||||
def _default_cm():
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value.all.return_value = []
|
||||
session.scalar.return_value = None
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
sessions.append(session)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = [cm1, cm2] + [_default_cm() for _ in range(10)]
|
||||
|
||||
clean_document_task(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
doc_form="paragraph",
|
||||
file_id=None,
|
||||
)
|
||||
|
||||
assert mock_index_processor_factory["processor"].clean.call_count == 1
|
||||
# Index cleanup invoked with the expected delete_summaries / delete_child_chunks flags.
|
||||
_, kwargs = mock_index_processor_factory["processor"].clean.call_args
|
||||
assert kwargs.get("with_keywords") is True
|
||||
assert kwargs.get("delete_child_chunks") is True
|
||||
assert kwargs.get("delete_summaries") is True
|
||||
|
||||
def test_no_segments_skips_vector_cleanup(
|
||||
self,
|
||||
document_id,
|
||||
dataset_id,
|
||||
tenant_id,
|
||||
mock_session_factory,
|
||||
mock_storage,
|
||||
mock_index_processor_factory,
|
||||
):
|
||||
"""When the document has no segments (e.g. indexing failed before
|
||||
producing any), vector cleanup must not be attempted — and therefore
|
||||
the new try/except wrapper does not change behavior here.
|
||||
"""
|
||||
mock_sf, sessions = mock_session_factory
|
||||
|
||||
step1_session = MagicMock()
|
||||
step1_session.scalars.return_value.all.return_value = [] # no segments
|
||||
step1_session.execute.return_value.all.return_value = []
|
||||
step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
|
||||
cm1 = MagicMock()
|
||||
cm1.__enter__.return_value = step1_session
|
||||
cm1.__exit__.return_value = None
|
||||
|
||||
def _default_cm():
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value.all.return_value = []
|
||||
session.scalar.return_value = None
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
cm.__exit__.return_value = None
|
||||
sessions.append(session)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = [cm1] + [_default_cm() for _ in range(10)]
|
||||
|
||||
clean_document_task(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
doc_form="paragraph",
|
||||
file_id=None,
|
||||
)
|
||||
|
||||
# Vector cleanup is gated on ``index_node_ids``; when there are no
|
||||
# segments the IndexProcessorFactory path is never entered.
|
||||
mock_index_processor_factory["factory_cls"].assert_not_called()
|
||||
@ -0,0 +1,289 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
MODULE = "tasks.process_tenant_plugin_autoupgrade_check_task"
|
||||
|
||||
|
||||
def _make_plugin(plugin_id: str, version: str, source=PluginInstallationSource.Marketplace):
    """Build a minimal stand-in for a PluginInstallation entry returned by manager.list_plugins."""
    # A plugin unique identifier looks like "<org>/<name>:<version>@<checksum>".
    unique_identifier = f"{plugin_id}:{version}@deadbeef"
    return SimpleNamespace(
        plugin_id=plugin_id,
        version=version,
        plugin_unique_identifier=unique_identifier,
        source=source,
    )
|
||||
|
||||
|
||||
def _make_manifest(plugin_id: str, latest_version: str) -> MarketplacePluginSnapshot:
    """Build a marketplace snapshot advertising *latest_version* for *plugin_id*."""
    org, name = plugin_id.split("/", 1)
    package_identifier = f"{plugin_id}:{latest_version}@cafe1234"
    package_url = f"https://marketplace.example/{plugin_id}/{latest_version}.difypkg"
    return MarketplacePluginSnapshot(
        org=org,
        name=name,
        latest_version=latest_version,
        latest_package_identifier=package_identifier,
        latest_package_url=package_url,
    )
|
||||
|
||||
|
||||
def _run_task(
    *,
    plugins: list,
    manifests: list[MarketplacePluginSnapshot],
    strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
    upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
    exclude_plugins=None,
    include_plugins=None,
):
    """
    Execute the celery task synchronously with mocks for the plugin manager,
    the marketplace cache and PluginService.upgrade_plugin_with_marketplace.
    Returns the upgrade-call recorder so each test can assert on it.
    """
    installer_stub = MagicMock()
    installer_stub.list_plugins.return_value = plugins

    recorded: list[tuple[str, str, str]] = []

    def _capture(tenant_id, original, new):
        # Record every (tenant, old identifier, new identifier) upgrade attempt.
        recorded.append((tenant_id, original, new))

    with (
        patch(f"{MODULE}.PluginInstaller", return_value=installer_stub),
        patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=manifests),
        patch(
            f"{MODULE}.PluginService.upgrade_plugin_with_marketplace",
            side_effect=_capture,
        ) as upgrade_mock,
    ):
        # Import inside the patch context so the task module resolves the mocks.
        from tasks.process_tenant_plugin_autoupgrade_check_task import (
            process_tenant_plugin_autoupgrade_check_task,
        )

        process_tenant_plugin_autoupgrade_check_task(
            "tenant-1",
            strategy_setting,
            0,
            upgrade_mode,
            exclude_plugins or [],
            include_plugins or [],
        )

    return upgrade_mock, recorded
|
||||
|
||||
|
||||
class TestUpgradeCallsMarketplaceService:
    """
    Regression test for the bug where the auto-upgrade task called
    manager.upgrade_plugin directly, which skipped downloading the new package
    from marketplace and uploading it to the daemon. The daemon then failed with
    "package file not found" and the upgrade silently never completed.
    """

    def test_upgrade_routes_through_plugin_service(self):
        installed = _make_plugin("acme/foo", "1.0.0")
        available = _make_manifest("acme/foo", "1.0.1")

        upgrade_mock, calls = _run_task(plugins=[installed], manifests=[available])

        upgrade_mock.assert_called_once()
        expected = [("tenant-1", installed.plugin_unique_identifier, available.latest_package_identifier)]
        assert calls == expected

    def test_does_not_call_manager_upgrade_plugin_directly(self):
        """Locks in that we never go back to the broken path that bypassed download/upload."""
        installed = _make_plugin("acme/foo", "1.0.0")
        available = _make_manifest("acme/foo", "1.0.1")

        installer_stub = MagicMock()
        installer_stub.list_plugins.return_value = [installed]

        with (
            patch(f"{MODULE}.PluginInstaller", return_value=installer_stub),
            patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=[available]),
            patch(f"{MODULE}.PluginService.upgrade_plugin_with_marketplace"),
        ):
            from tasks.process_tenant_plugin_autoupgrade_check_task import (
                process_tenant_plugin_autoupgrade_check_task,
            )

            process_tenant_plugin_autoupgrade_check_task(
                "tenant-1",
                TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
                0,
                TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
                [],
                [],
            )

        installer_stub.upgrade_plugin.assert_not_called()
|
||||
|
||||
|
||||
class TestStrategySetting:
    """Coverage for the per-tenant ``strategy_setting`` switch (DISABLED / FIX_ONLY / LATEST)."""

    @staticmethod
    def _single_plugin_run(current: str, latest: str, setting):
        # One installed plugin at *current*, marketplace advertising *latest*.
        return _run_task(
            plugins=[_make_plugin("acme/foo", current)],
            manifests=[_make_manifest("acme/foo", latest)],
            strategy_setting=setting,
        )

    def test_disabled_strategy_skips_everything(self):
        service_mock, _ = self._single_plugin_run(
            "1.0.0", "1.0.1", TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED
        )
        service_mock.assert_not_called()

    def test_fix_only_upgrades_patch_version(self):
        service_mock, calls = self._single_plugin_run(
            "1.0.0", "1.0.5", TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
        )
        service_mock.assert_called_once()
        assert calls[0][2].endswith(":1.0.5@cafe1234")

    def test_fix_only_skips_minor_bump(self):
        service_mock, _ = self._single_plugin_run(
            "1.0.0", "1.1.0", TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
        )
        service_mock.assert_not_called()

    def test_fix_only_skips_major_bump(self):
        service_mock, _ = self._single_plugin_run(
            "1.0.0", "2.0.0", TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
        )
        service_mock.assert_not_called()

    def test_latest_strategy_skips_when_versions_equal(self):
        service_mock, _ = self._single_plugin_run(
            "1.0.0", "1.0.0", TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
        )
        service_mock.assert_not_called()
|
||||
|
||||
|
||||
class TestUpgradeMode:
    """Coverage for the ALL / PARTIAL / EXCLUDE upgrade-mode filters."""

    @staticmethod
    def _fixtures(bar_source=PluginInstallationSource.Marketplace):
        # Two installed plugins with a newer marketplace version for each.
        plugins = [
            _make_plugin("acme/foo", "1.0.0"),
            _make_plugin("acme/bar", "2.0.0", source=bar_source),
        ]
        manifests = [
            _make_manifest("acme/foo", "1.0.1"),
            _make_manifest("acme/bar", "2.0.1"),
        ]
        return plugins, manifests

    def test_mode_all_upgrades_every_marketplace_plugin(self):
        plugins, manifests = self._fixtures()

        service_mock, calls = _run_task(
            plugins=plugins,
            manifests=manifests,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
        )

        assert service_mock.call_count == 2
        assert sorted(c[1] for c in calls) == sorted(p.plugin_unique_identifier for p in plugins)

    def test_mode_all_skips_non_marketplace_sources(self):
        plugins, manifests = self._fixtures(bar_source=PluginInstallationSource.Github)

        service_mock, calls = _run_task(
            plugins=plugins,
            manifests=manifests,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
        )

        assert service_mock.call_count == 1
        assert calls[0][1] == plugins[0].plugin_unique_identifier

    def test_mode_partial_only_upgrades_included_plugins(self):
        plugins, manifests = self._fixtures()

        service_mock, calls = _run_task(
            plugins=plugins,
            manifests=manifests,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
            include_plugins=["acme/foo"],
        )

        assert service_mock.call_count == 1
        assert calls[0][1] == plugins[0].plugin_unique_identifier

    def test_mode_exclude_skips_excluded_plugins(self):
        plugins, manifests = self._fixtures()

        service_mock, calls = _run_task(
            plugins=plugins,
            manifests=manifests,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
            exclude_plugins=["acme/bar"],
        )

        assert service_mock.call_count == 1
        assert calls[0][1] == plugins[0].plugin_unique_identifier
|
||||
|
||||
|
||||
class TestErrorIsolation:
    def test_one_plugin_failure_does_not_block_others(self):
        """An exception while upgrading one plugin must not abort the remaining upgrades."""
        plugins = [
            _make_plugin("acme/foo", "1.0.0"),
            _make_plugin("acme/bar", "2.0.0"),
        ]
        manifests = [
            _make_manifest("acme/foo", "1.0.1"),
            _make_manifest("acme/bar", "2.0.1"),
        ]

        installer_stub = MagicMock()
        installer_stub.list_plugins.return_value = plugins

        attempted: list[str] = []

        def _exploding_upgrade(tenant_id, original, new):
            # Record the attempt, then blow up for "foo" only.
            attempted.append(original)
            if "foo" in original:
                raise RuntimeError("boom")

        with (
            patch(f"{MODULE}.PluginInstaller", return_value=installer_stub),
            patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=manifests),
            patch(f"{MODULE}.PluginService.upgrade_plugin_with_marketplace", side_effect=_exploding_upgrade),
        ):
            from tasks.process_tenant_plugin_autoupgrade_check_task import (
                process_tenant_plugin_autoupgrade_check_task,
            )

            process_tenant_plugin_autoupgrade_check_task(
                "tenant-1",
                TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
                0,
                TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
                [],
                [],
            )

        # Both plugins were attempted even though the first one raised.
        assert any("foo" in s for s in attempted)
        assert any("bar" in s for s in attempted)
|
||||
@ -35,8 +35,8 @@
|
||||
"stepOne.uploader.cancel": "Cancel",
|
||||
"stepOne.uploader.change": "Change",
|
||||
"stepOne.uploader.failed": "Upload failed",
|
||||
"stepOne.uploader.tip": "Supports {{supportTypes}}. Max {{batchCount}} in a batch and {{size}} MB each.",
|
||||
"stepOne.uploader.tipWithTotalLimit": "Supports {{supportTypes}}. Max {{batchCount}} in a batch and {{size}} MB each. Max total {{totalCount}} files.",
|
||||
"stepOne.uploader.tip": "Ondersteunt {{supportTypes}}. Maximaal {{batchCount}} per batch en {{size}} MB per bestand.",
|
||||
"stepOne.uploader.tipWithTotalLimit": "Ondersteunt {{supportTypes}}. Maximaal {{batchCount}} per batch en {{size}} MB per bestand. Maximaal {{totalCount}} bestanden in totaal.",
|
||||
"stepOne.uploader.title": "Upload file",
|
||||
"stepOne.uploader.validation.count": "Multiple files not supported",
|
||||
"stepOne.uploader.validation.filesNumber": "You have reached the batch upload limit of {{filesNumber}}.",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user