Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-04-28 11:02:44 +08:00
commit 9dd73b4d47
15 changed files with 1051 additions and 179 deletions

View File

@ -39,6 +39,58 @@ class AbstractVectorFactory(ABC):
return index_struct_dict
class _LazyEmbeddings(Embeddings):
"""Lazy proxy that defers materializing the real embedding model.
Constructing the real embeddings (via ``ModelManager.get_model_instance``)
transitively calls ``FeatureService.get_features``, which issues
``BillingService`` HTTP GETs (see ``provider_manager.py``). Cleanup paths
(``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings
at all, so deferring this until an ``embed_*`` method is actually invoked
keeps cleanup tasks resilient to transient billing-API failures and avoids
leaving stranded ``document_segments`` / ``child_chunks`` whenever the
billing service hiccups.
Existing callers that perform create / search operations are unaffected:
the first ``embed_*`` call materializes the underlying model and the
behavior is identical from that point on.
"""
def __init__(self, dataset: Dataset):
self._dataset = dataset
self._real: Embeddings | None = None
def _ensure(self) -> Embeddings:
if self._real is None:
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model,
)
self._real = CacheEmbedding(embedding_model)
return self._real
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._ensure().embed_documents(texts)
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
return self._ensure().embed_multimodal_documents(multimodel_documents)
def embed_query(self, text: str) -> list[float]:
return self._ensure().embed_query(text)
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
return self._ensure().embed_multimodal_query(multimodel_document)
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return await self._ensure().aembed_documents(texts)
async def aembed_query(self, text: str) -> list[float]:
return await self._ensure().aembed_query(text)
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
@ -60,7 +112,11 @@ class Vector:
"original_chunk_id",
]
self._dataset = dataset
self._embeddings = self._get_embeddings()
# Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists)
# never transitively trigger billing API calls during ``Vector(dataset)``
# construction. The real embedding model is materialized only when an
# ``embed_*`` method is actually invoked (i.e. create / search paths).
self._embeddings: Embeddings = _LazyEmbeddings(dataset)
self._attributes = attributes
self._vector_processor = self._init_vector()

View File

@ -2182,7 +2182,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
return result
class UploadFile(Base):
class UploadFile(TypeBase):
__tablename__ = "upload_files"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="upload_file_pkey"),
@ -2190,9 +2190,12 @@ class UploadFile(Base):
)
# NOTE: The `id` field is generated within the application to minimize extra roundtrips
# (especially when generating `source_url`).
# The `server_default` serves as a fallback mechanism.
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
# (especially when generating `source_url`) and keep model metadata portable across databases.
id: Mapped[str] = mapped_column(
StringUUID,
init=False,
default_factory=lambda: str(uuid4()),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
key: Mapped[str] = mapped_column(String(255), nullable=False)
@ -2200,16 +2203,6 @@ class UploadFile(Base):
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
extension: Mapped[str] = mapped_column(String(255), nullable=False)
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
# Its value is derived from the `CreatorUserRole` enumeration.
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255),
nullable=False,
server_default=sa.text("'account'"),
default=CreatorUserRole.ACCOUNT,
)
# The `created_by` field stores the ID of the entity that created this upload file.
#
# If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`.
@ -2228,10 +2221,18 @@ class UploadFile(Base):
# `used` may indicate whether the file has been utilized by another service.
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
# Its value is derived from the `CreatorUserRole` enumeration.
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255),
nullable=False,
server_default=sa.text("'account'"),
default=CreatorUserRole.ACCOUNT,
)
# `used_by` may indicate the ID of the user who utilized this file.
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True, default=None)
hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
source_url: Mapped[str] = mapped_column(LongText, default="")
def __init__(

View File

@ -50,7 +50,7 @@ from libs.uuid_utils import uuidv7
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
if TYPE_CHECKING:
from .model import AppMode, UploadFile
from .model import AppMode
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
@ -63,6 +63,10 @@ from .account import Account
from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
# UploadFile uses TypeBase while workflow execution offload models use Base, so relationships
# must target the class object directly instead of relying on string lookup across registries.
from .model import UploadFile
from .types import EnumText, LongText, StringUUID
from .utils.file_input_compat import (
build_file_from_mapping_without_lookup,
@ -1096,8 +1100,6 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def _load_full_content(session: orm.Session, file_id: str, storage: Storage):
from .model import UploadFile
stmt = sa.select(UploadFile).where(UploadFile.id == file_id)
file = session.scalars(stmt).first()
assert file is not None, f"UploadFile with id {file_id} should exist but not"
@ -1191,10 +1193,11 @@ class WorkflowNodeExecutionOffload(Base):
)
file: Mapped[Optional["UploadFile"]] = orm.relationship(
UploadFile,
foreign_keys=[file_id],
lazy="raise",
uselist=False,
primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id",
primaryjoin=lambda: orm.foreign(WorkflowNodeExecutionOffload.file_id) == UploadFile.id,
)
@ -1968,10 +1971,11 @@ class WorkflowDraftVariableFile(Base):
# Relationship to UploadFile
upload_file: Mapped["UploadFile"] = orm.relationship(
UploadFile,
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id",
primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id,
)

View File

@ -61,13 +61,31 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
# check segment is exist
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
# Wrap vector / keyword index cleanup in try/except so that a transient
# failure here (e.g. billing API hiccup propagated via FeatureService when
# ModelManager is initialized inside ``Vector(dataset)``) does not abort
# the entire task and leave document_segments / child_chunks / image_files
# / metadata bindings stranded in PG. Mirrors the pattern already used in
# ``clean_dataset_task`` so the document row's hard delete (already
# committed by the caller) does not produce orphan PG rows just because
# the vector backend or one of its transitive dependencies was unhappy.
try:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception(
"Failed to clean vector / keyword index in clean_document_task, "
"document_id=%s, dataset_id=%s, index_node_ids_count=%d. "
"Continuing with PG / storage cleanup; vector orphans can be reaped later.",
document_id,
dataset_id,
len(index_node_ids),
)
total_image_files = []
with session_factory.create_session() as session, session.begin():

View File

@ -40,12 +40,29 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
total_index_node_ids.extend([segment.index_node_id for segment in segments])
with session_factory.create_session() as session:
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
# Wrap vector / keyword index cleanup in try/except so that a transient
# failure here (e.g. billing API hiccup propagated via FeatureService when
# ``ModelManager`` is initialized inside ``Vector(dataset)``) does not abort
# the task and leave the already-deleted documents' segments stranded in PG.
# The Document rows are hard-deleted in the previous session block, so any
# exception escaping this task would produce orphans that no later request
# can reference back. Mirrors the pattern in ``clean_dataset_task``.
try:
with session_factory.create_session() as session:
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception(
"Failed to clean vector / keyword index in clean_notion_document_task, "
"dataset_id=%s, document_ids=%s, index_node_ids_count=%d. "
"Continuing with segment deletion; vector orphans can be reaped later.",
dataset_id,
document_ids,
len(total_index_node_ids),
)
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))

View File

@ -11,6 +11,7 @@ from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client
from models.account import TenantPluginAutoUpgradeStrategy
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
@ -171,14 +172,13 @@ def process_tenant_plugin_autoupgrade_check_task(
fg="green",
)
)
_ = manager.upgrade_plugin(
# Use the service that downloads and uploads the package to the daemon
# first; calling manager.upgrade_plugin directly skips that step and the
# daemon fails because the package never reaches its local bucket.
_ = PluginService.upgrade_plugin_with_marketplace(
tenant_id,
original_unique_identifier,
new_unique_identifier,
PluginInstallationSource.Marketplace,
{
"plugin_unique_identifier": new_unique_identifier,
},
)
except Exception as e:
click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red"))

View File

@ -602,14 +602,25 @@ class TestCleanNotionDocumentTask:
# Note: This test successfully verifies database operations.
# IndexProcessor verification would require more sophisticated mocking.
def test_clean_notion_document_task_database_transaction_rollback(
def test_clean_notion_document_task_continues_when_index_processor_fails(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task behavior when database operations fail.
Index processor failure (e.g. transient billing API error propagated via
``FeatureService`` when ``Vector(dataset)`` lazily resolves the embedding
model) must NOT abort the cleanup task. The Document rows have already
been hard-deleted in the first session block before vector cleanup runs,
so any uncaught exception escaping the task would strand
``DocumentSegment`` rows in PG with no parent ``Document``.
This test verifies that the task properly handles database errors
and maintains data consistency.
Contract: the task swallows the index_processor exception, logs it, and
proceeds to delete the segments, leaving PG consistent. (Vector orphans,
if any, can be reaped later by an offline scanner.)
Regression guard for the production incident where ``clean_document_task``
/ ``clean_notion_document_task`` failed with
``ValueError("Unable to retrieve billing information...")`` and left
tens of thousands of orphan segments per affected tenant.
"""
fake = Faker()
@ -672,17 +683,28 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Mock index processor to raise an exception
# Simulate the production failure mode: index_processor.clean() raises a
# ValueError mirroring ``BillingService._send_request`` returning non-200.
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_index_processor.clean.side_effect = Exception("Index processor error")
mock_index_processor.clean.side_effect = ValueError(
"Unable to retrieve billing information. Please try again later or contact support."
)
# Execute cleanup task - current implementation propagates the exception
with pytest.raises(Exception, match="Index processor error"):
clean_notion_document_task([document.id], dataset.id)
# Execute cleanup task — must NOT raise even though clean() raises.
# Before the safety-net wrapper this would have re-raised the ValueError,
# aborting the task and leaving DocumentSegment stranded in PG.
clean_notion_document_task([document.id], dataset.id)
# Note: This test demonstrates the task's error handling capability.
# Even with external service errors, the database operations complete successfully.
# In a production environment, proper error handling would determine transaction rollback behavior.
# Vector cleanup was attempted exactly once.
mock_index_processor.clean.assert_called_once()
# The crucial assertion: despite the index processor failure, the
# final session block (line 51-52, ``DELETE FROM document_segments``)
# still ran and committed. This is what the wrapper buys us — without
# it the production incident left tens of thousands of orphan segments
# per affected tenant. Aligns with the assertion shape used by the
# happy-path test (``test_clean_notion_document_task_success``).
assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0
def test_clean_notion_document_task_with_large_number_of_documents(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies

View File

@ -146,10 +146,7 @@ def test_get_vector_factory_entry_point_overrides_builtin(vector_factory_module,
def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
dataset = SimpleNamespace(id="dataset-1")
with (
patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"),
patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"),
):
with patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"):
default_vector = vector_factory_module.Vector(dataset)
custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"])
@ -166,10 +163,57 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module):
"original_chunk_id",
]
assert custom_vector._attributes == ["doc_id"]
assert default_vector._embeddings == "embeddings"
# ``_embeddings`` is now a lazy proxy that defers materializing the real
# embedding model until ``embed_*`` is invoked, so cleanup paths never
# trigger billing/feature-service calls during ``Vector(dataset)``
# construction. See ``_LazyEmbeddings``.
assert isinstance(default_vector._embeddings, vector_factory_module._LazyEmbeddings)
assert default_vector._vector_processor == "processor"
def test_lazy_embeddings_defer_real_load_until_first_embed_call(vector_factory_module, monkeypatch):
"""``Vector(dataset)`` must not transitively call ``ModelManager`` during
construction. The real embedding model should only be materialized on the
first ``embed_*`` call (i.e. create / search paths) so cleanup paths
(``delete_by_ids`` / ``delete``) remain resilient to billing-API failures.
"""
for_tenant_mock = MagicMock(side_effect=AssertionError("ModelManager.for_tenant must not be called eagerly"))
monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock)
dataset = SimpleNamespace(
tenant_id="tenant-1",
embedding_model_provider="openai",
embedding_model="text-embedding-3-small",
)
proxy = vector_factory_module._LazyEmbeddings(dataset)
# Construction alone does not trigger ModelManager / FeatureService / BillingService.
for_tenant_mock.assert_not_called()
# Exercising an embed_* method materializes the real model exactly once.
inner_model = MagicMock()
inner_model.embed_documents.return_value = [[0.1, 0.2]]
cached_embedding_mock = MagicMock(return_value=inner_model)
real_for_tenant = MagicMock()
real_for_tenant.get_model_instance.return_value = "embedding-model-instance"
monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", MagicMock(return_value=real_for_tenant))
monkeypatch.setattr(vector_factory_module, "CacheEmbedding", cached_embedding_mock)
result = proxy.embed_documents(["hello"])
assert result == [[0.1, 0.2]]
cached_embedding_mock.assert_called_once_with("embedding-model-instance")
inner_model.embed_documents.assert_called_once_with(["hello"])
# Subsequent calls reuse the materialized model (no re-resolution).
inner_model.embed_documents.reset_mock()
cached_embedding_mock.reset_mock()
proxy.embed_documents(["world"])
cached_embedding_mock.assert_not_called()
inner_model.embed_documents.assert_called_once_with(["world"])
def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch):
calls = {"vector_type": None, "init_args": None}

View File

@ -45,7 +45,7 @@ class TestWorkflowModelValidation:
workflow = Workflow.new(
tenant_id=tenant_id,
app_id=app_id,
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph=graph,
features=features,
@ -58,7 +58,7 @@ class TestWorkflowModelValidation:
# Assert
assert workflow.tenant_id == tenant_id
assert workflow.app_id == app_id
assert workflow.type == WorkflowType.WORKFLOW.value
assert workflow.type == WorkflowType.WORKFLOW
assert workflow.version == "draft"
assert workflow.graph == graph
assert workflow.created_by == created_by
@ -68,7 +68,7 @@ class TestWorkflowModelValidation:
def test_workflow_type_enum_values(self):
"""Test WorkflowType enum values."""
# Assert
assert WorkflowType.WORKFLOW.value == "workflow"
assert WorkflowType.WORKFLOW == "workflow"
assert WorkflowType.CHAT.value == "chat"
assert WorkflowType.RAG_PIPELINE.value == "rag-pipeline"
@ -89,7 +89,7 @@ class TestWorkflowModelValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph=json.dumps(graph_data),
features="{}",
@ -114,7 +114,7 @@ class TestWorkflowModelValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph="{}",
features=json.dumps(features_data),
@ -138,7 +138,7 @@ class TestWorkflowModelValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="v1.0",
graph="{}",
features="{}",
@ -176,11 +176,11 @@ class TestWorkflowRunStateTransitions:
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
version="draft",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=created_by,
)
@ -188,9 +188,9 @@ class TestWorkflowRunStateTransitions:
assert workflow_run.tenant_id == tenant_id
assert workflow_run.app_id == app_id
assert workflow_run.workflow_id == workflow_id
assert workflow_run.type == WorkflowType.WORKFLOW.value
assert workflow_run.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
assert workflow_run.status == WorkflowExecutionStatus.RUNNING.value
assert workflow_run.type == WorkflowType.WORKFLOW
assert workflow_run.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
assert workflow_run.created_by == created_by
def test_workflow_run_state_transition_running_to_succeeded(self):
@ -200,21 +200,21 @@ class TestWorkflowRunStateTransitions:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.END_USER.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.END_USER,
created_by=str(uuid4()),
)
# Act
workflow_run.status = WorkflowExecutionStatus.SUCCEEDED.value
workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
workflow_run.finished_at = datetime.now(UTC)
workflow_run.elapsed_time = 2.5
# Assert
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED.value
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED
assert workflow_run.finished_at is not None
assert workflow_run.elapsed_time == 2.5
@ -225,21 +225,21 @@ class TestWorkflowRunStateTransitions:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
)
# Act
workflow_run.status = WorkflowExecutionStatus.FAILED.value
workflow_run.status = WorkflowExecutionStatus.FAILED
workflow_run.error = "Node execution failed: Invalid input"
workflow_run.finished_at = datetime.now(UTC)
# Assert
assert workflow_run.status == WorkflowExecutionStatus.FAILED.value
assert workflow_run.status == WorkflowExecutionStatus.FAILED
assert workflow_run.error == "Node execution failed: Invalid input"
assert workflow_run.finished_at is not None
@ -250,20 +250,20 @@ class TestWorkflowRunStateTransitions:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
version="draft",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
)
# Act
workflow_run.status = WorkflowExecutionStatus.STOPPED.value
workflow_run.status = WorkflowExecutionStatus.STOPPED
workflow_run.finished_at = datetime.now(UTC)
# Assert
assert workflow_run.status == WorkflowExecutionStatus.STOPPED.value
assert workflow_run.status == WorkflowExecutionStatus.STOPPED
assert workflow_run.finished_at is not None
def test_workflow_run_state_transition_running_to_paused(self):
@ -273,19 +273,19 @@ class TestWorkflowRunStateTransitions:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.END_USER.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.END_USER,
created_by=str(uuid4()),
)
# Act
workflow_run.status = WorkflowExecutionStatus.PAUSED.value
workflow_run.status = WorkflowExecutionStatus.PAUSED
# Assert
assert workflow_run.status == WorkflowExecutionStatus.PAUSED.value
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
assert workflow_run.finished_at is None # Not finished when paused
def test_workflow_run_state_transition_paused_to_running(self):
@ -295,19 +295,19 @@ class TestWorkflowRunStateTransitions:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.PAUSED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.PAUSED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
)
# Act
workflow_run.status = WorkflowExecutionStatus.RUNNING.value
workflow_run.status = WorkflowExecutionStatus.RUNNING
# Assert
assert workflow_run.status == WorkflowExecutionStatus.RUNNING.value
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
def test_workflow_run_with_partial_succeeded_status(self):
"""Test workflow run with partial-succeeded status."""
@ -316,17 +316,17 @@ class TestWorkflowRunStateTransitions:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
exceptions_count=2,
)
# Assert
assert workflow_run.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value
assert workflow_run.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
assert workflow_run.exceptions_count == 2
def test_workflow_run_with_inputs_and_outputs(self):
@ -340,11 +340,11 @@ class TestWorkflowRunStateTransitions:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.END_USER.value,
status=WorkflowExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.END_USER,
created_by=str(uuid4()),
inputs=json.dumps(inputs),
outputs=json.dumps(outputs),
@ -362,11 +362,11 @@ class TestWorkflowRunStateTransitions:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
version="draft",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
graph=json.dumps(graph),
)
@ -391,11 +391,11 @@ class TestWorkflowRunStateTransitions:
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=created_by,
total_tokens=1500,
total_steps=5,
@ -410,7 +410,7 @@ class TestWorkflowRunStateTransitions:
assert result["tenant_id"] == tenant_id
assert result["app_id"] == app_id
assert result["workflow_id"] == workflow_id
assert result["status"] == WorkflowExecutionStatus.SUCCEEDED.value
assert result["status"] == WorkflowExecutionStatus.SUCCEEDED
assert result["total_tokens"] == 1500
assert result["total_steps"] == 5
@ -422,18 +422,18 @@ class TestWorkflowRunStateTransitions:
"tenant_id": str(uuid4()),
"app_id": str(uuid4()),
"workflow_id": str(uuid4()),
"type": WorkflowType.WORKFLOW.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"type": WorkflowType.WORKFLOW,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
"version": "v1.0",
"graph": {"nodes": [], "edges": []},
"inputs": {"query": "test"},
"status": WorkflowExecutionStatus.SUCCEEDED.value,
"status": WorkflowExecutionStatus.SUCCEEDED,
"outputs": {"result": "success"},
"error": None,
"elapsed_time": 3.5,
"total_tokens": 2000,
"total_steps": 10,
"created_by_role": CreatorUserRole.ACCOUNT.value,
"created_by_role": CreatorUserRole.ACCOUNT,
"created_by": str(uuid4()),
"created_at": datetime.now(UTC),
"finished_at": datetime.now(UTC),
@ -446,7 +446,7 @@ class TestWorkflowRunStateTransitions:
# Assert
assert workflow_run.id == data["id"]
assert workflow_run.workflow_id == data["workflow_id"]
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED.value
assert workflow_run.status == WorkflowExecutionStatus.SUCCEEDED
assert workflow_run.total_tokens == 2000
@ -467,14 +467,14 @@ class TestNodeExecutionRelationships:
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=workflow_run_id,
index=1,
node_id="start",
node_type=BuiltinNodeTypes.START,
title="Start Node",
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=created_by,
)
@ -498,15 +498,15 @@ class TestNodeExecutionRelationships:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=2,
predecessor_node_id=predecessor_node_id,
node_id=current_node_id,
node_type=BuiltinNodeTypes.LLM,
title="LLM Node",
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
)
@ -528,8 +528,8 @@ class TestNodeExecutionRelationships:
node_id="llm_test",
node_type=BuiltinNodeTypes.LLM,
title="Test LLM",
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
)
@ -549,14 +549,14 @@ class TestNodeExecutionRelationships:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=1,
node_id="llm_1",
node_type=BuiltinNodeTypes.LLM,
title="LLM Node",
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
inputs=json.dumps(inputs),
outputs=json.dumps(outputs),
@ -575,24 +575,24 @@ class TestNodeExecutionRelationships:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=1,
node_id="code_1",
node_type=BuiltinNodeTypes.CODE,
title="Code Node",
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
)
# Act - transition to succeeded
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
node_execution.elapsed_time = 1.2
node_execution.finished_at = datetime.now(UTC)
# Assert
assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert node_execution.elapsed_time == 1.2
assert node_execution.finished_at is not None
@ -606,20 +606,20 @@ class TestNodeExecutionRelationships:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=3,
node_id="code_1",
node_type=BuiltinNodeTypes.CODE,
title="Code Node",
status=WorkflowNodeExecutionStatus.FAILED.value,
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
created_by_role=CreatorUserRole.ACCOUNT.value,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
)
# Assert
assert node_execution.status == WorkflowNodeExecutionStatus.FAILED.value
assert node_execution.status == WorkflowNodeExecutionStatus.FAILED
assert node_execution.error == error_message
def test_node_execution_with_metadata(self):
@ -637,14 +637,14 @@ class TestNodeExecutionRelationships:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=1,
node_id="llm_1",
node_type=BuiltinNodeTypes.LLM,
title="LLM Node",
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
execution_metadata=json.dumps(metadata),
)
@ -660,14 +660,14 @@ class TestNodeExecutionRelationships:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=1,
node_id="start",
node_type=BuiltinNodeTypes.START,
title="Start",
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
execution_metadata=None,
)
@ -696,14 +696,14 @@ class TestNodeExecutionRelationships:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=1,
node_id=f"{node_type}_1",
node_type=node_type,
title=title,
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
)
@ -734,7 +734,7 @@ class TestGraphConfigurationValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph=json.dumps(graph_config),
features="{}",
@ -761,7 +761,7 @@ class TestGraphConfigurationValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph=json.dumps(graph_config),
features="{}",
@ -802,7 +802,7 @@ class TestGraphConfigurationValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph=json.dumps(graph_config),
features="{}",
@ -835,11 +835,11 @@ class TestGraphConfigurationValidation:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
graph=json.dumps(original_graph),
)
@ -872,7 +872,7 @@ class TestGraphConfigurationValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph=json.dumps(graph_config),
features="{}",
@ -912,7 +912,7 @@ class TestGraphConfigurationValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph=json.dumps(graph_config),
features="{}",
@ -933,7 +933,7 @@ class TestGraphConfigurationValidation:
workflow = Workflow.new(
tenant_id=str(uuid4()),
app_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
type=WorkflowType.WORKFLOW,
version="draft",
graph=None,
features="{}",
@ -956,11 +956,11 @@ class TestGraphConfigurationValidation:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
inputs=None,
)
@ -978,11 +978,11 @@ class TestGraphConfigurationValidation:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type=WorkflowType.WORKFLOW.value,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
type=WorkflowType.WORKFLOW,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
version="v1.0",
status=WorkflowExecutionStatus.RUNNING.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowExecutionStatus.RUNNING,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
outputs=None,
)
@ -1000,14 +1000,14 @@ class TestGraphConfigurationValidation:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=1,
node_id="start",
node_type=BuiltinNodeTypes.START,
title="Start",
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
inputs=None,
)
@ -1025,14 +1025,14 @@ class TestGraphConfigurationValidation:
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=str(uuid4()),
index=1,
node_id="start",
node_type=BuiltinNodeTypes.START,
title="Start",
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
created_by_role=CreatorUserRole.ACCOUNT.value,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
outputs=None,
)

View File

@ -0,0 +1,70 @@
import base64
import hashlib
import os
from io import BytesIO
from types import SimpleNamespace
import pytest
from _pytest.monkeypatch import MonkeyPatch
from baidubce.services.bos.bos_client import BosClient
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,
get_example_data,
get_example_filename,
get_example_filepath,
)
class MockBaiduObsClass:
    """Drop-in stand-in for ``baidubce`` ``BosClient`` used when MOCK_SWITCH is on.

    NOTE(review): each method is monkeypatched onto ``BosClient`` individually
    (see ``setup_baidu_obs_mock``), so every method must stay self-contained —
    it must not call helpers defined only on this class, because ``self`` at
    call time is a real ``BosClient`` instance.
    """

    def __init__(self, config=None):
        # Seed the canonical example values that every method asserts against.
        self.bucket_name = get_example_bucket()
        self.key = get_example_filename()
        self.content = get_example_data()
        self.filepath = get_example_filepath()

    def put_object(self, bucket_name, key, data, content_length=None, content_md5=None, **kwargs):
        """Verify upload arguments, including the base64-encoded MD5 checksum."""
        assert bucket_name == self.bucket_name
        assert key == self.key
        assert data == self.content
        assert content_length == len(self.content)
        # BOS expects content_md5 as the base64 encoding of the raw MD5 digest.
        expected_md5 = base64.standard_b64encode(hashlib.md5(self.content).digest())
        assert content_md5 == expected_md5

    def get_object(self, bucket_name, key, **kwargs):
        """Return an object whose ``data`` attribute is a readable byte stream."""
        assert bucket_name == self.bucket_name
        assert key == self.key
        return SimpleNamespace(data=BytesIO(self.content))

    def get_object_to_file(self, bucket_name, key, file_name, **kwargs):
        """Verify the download-to-file call targets the example local path."""
        assert bucket_name == self.bucket_name
        assert key == self.key
        assert file_name == self.filepath

    def get_object_meta_data(self, bucket_name, key, **kwargs):
        """Return a 200-status stand-in, as used by existence checks."""
        assert bucket_name == self.bucket_name
        assert key == self.key
        return SimpleNamespace(status=200)

    def delete_object(self, bucket_name, key, **kwargs):
        """Verify deletion targets the example bucket/key; returns nothing."""
        assert bucket_name == self.bucket_name
        assert key == self.key
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_baidu_obs_mock(monkeypatch: MonkeyPatch):
    """Graft the mock client's methods onto ``BosClient`` when MOCK_SWITCH is on."""
    if MOCK:
        for method_name in (
            "__init__",
            "put_object",
            "get_object",
            "get_object_to_file",
            "get_object_meta_data",
            "delete_object",
        ):
            monkeypatch.setattr(BosClient, method_name, getattr(MockBaiduObsClass, method_name))
    yield
    if MOCK:
        monkeypatch.undo()

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,59 @@
from unittest.mock import MagicMock, patch
import pytest
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from extensions.storage.baidu_obs_storage import BaiduObsStorage
from tests.unit_tests.oss.__mock.baidu_obs import setup_baidu_obs_mock
from tests.unit_tests.oss.__mock.base import (
BaseStorageTest,
get_example_bucket,
)
class TestBaiduObs(BaseStorageTest):
    """Run the shared storage test suite against a ``BaiduObsStorage`` instance."""

    @pytest.fixture(autouse=True)
    def setup_method(self, setup_baidu_obs_mock):
        """Executed before each test method."""
        credentials_patch = patch.object(BceCredentials, "__init__", return_value=None)
        configuration_patch = patch.object(BceClientConfiguration, "__init__", return_value=None)
        with credentials_patch, configuration_patch:
            self.storage = BaiduObsStorage()
            self.storage.bucket_name = get_example_bucket()
class TestBaiduObsConfiguration:
    """Verify ``BaiduObsStorage`` wires ``dify_config`` values into the BOS client."""

    def test_init_with_config(self):
        config_stub = MagicMock()
        config_stub.BAIDU_OBS_BUCKET_NAME = "test-bucket"
        config_stub.BAIDU_OBS_ACCESS_KEY = "test-access-key"
        config_stub.BAIDU_OBS_SECRET_KEY = "test-secret-key"
        config_stub.BAIDU_OBS_ENDPOINT = "https://bj.bcebos.com"

        credentials_obj = MagicMock(name="credentials")
        client_config_obj = MagicMock(name="config")
        bos_client_obj = MagicMock(name="client")

        with (
            patch("extensions.storage.baidu_obs_storage.dify_config", config_stub),
            patch(
                "extensions.storage.baidu_obs_storage.BceCredentials", return_value=credentials_obj
            ) as credentials_cls,
            patch(
                "extensions.storage.baidu_obs_storage.BceClientConfiguration", return_value=client_config_obj
            ) as configuration_cls,
            patch("extensions.storage.baidu_obs_storage.BosClient", return_value=bos_client_obj) as bos_client_cls,
        ):
            storage = BaiduObsStorage()

        # The mocks retain their call records after the patches exit.
        assert storage.bucket_name == "test-bucket"
        assert storage.client == bos_client_obj
        credentials_cls.assert_called_once_with(
            access_key_id="test-access-key",
            secret_access_key="test-secret-key",
        )
        configuration_cls.assert_called_once_with(
            credentials=credentials_obj,
            endpoint="https://bj.bcebos.com",
        )
        bos_client_cls.assert_called_once_with(config=client_config_obj)

View File

@ -0,0 +1,291 @@
"""
Unit tests for clean_document_task.
Focuses on the resilience contract added by the billing-failure fix:
``index_processor.clean()`` is wrapped in ``try/except`` so that a transient
failure inside the vector / keyword cleanup (e.g. ``ValueError("Unable to
retrieve billing information...")`` raised by ``BillingService._send_request``
when ``Vector(dataset)`` transitively triggers ``FeatureService.get_features``)
does not abort the entire task and leave PG with stranded ``DocumentSegment``
/ ``ChildChunk`` / ``UploadFile`` / ``DatasetMetadataBinding`` rows.
"""
import uuid
from unittest.mock import MagicMock, patch
import pytest
from tasks.clean_document_task import clean_document_task
@pytest.fixture
def document_id():
    """Random document UUID string."""
    return f"{uuid.uuid4()}"


@pytest.fixture
def dataset_id():
    """Random dataset UUID string."""
    return f"{uuid.uuid4()}"


@pytest.fixture
def tenant_id():
    """Random tenant UUID string."""
    return f"{uuid.uuid4()}"
@pytest.fixture
def mock_session_factory():
    """Patch ``session_factory.create_session`` to hand out fresh mock sessions.

    Each ``create_session()`` call produces a new ``MagicMock`` session
    (wrapped in a context-manager mock) so tests can assert ``execute()``
    calls across the task's multiple short-lived transactions.
    """
    with patch("tasks.clean_document_task.session_factory", autospec=True) as factory_mock:
        created: list[MagicMock] = []

        def _new_session_cm():
            # Fresh session with empty default query results.
            sess = MagicMock()
            sess.scalars.return_value.all.return_value = []
            sess.execute.return_value.all.return_value = []
            sess.scalar.return_value = None
            created.append(sess)
            wrapper = MagicMock()
            wrapper.__enter__.return_value = sess
            wrapper.__exit__.return_value = None
            return wrapper

        factory_mock.create_session.side_effect = _new_session_cm
        yield factory_mock, created
@pytest.fixture
def mock_storage():
    """Patch the task's ``storage`` so file deletion becomes a no-op."""
    with patch("tasks.clean_document_task.storage", autospec=True) as storage_mock:
        storage_mock.delete.return_value = None
        yield storage_mock
@pytest.fixture
def mock_index_processor_factory():
    """Mock ``IndexProcessorFactory`` so tests can inject behavior into ``clean``."""
    with patch("tasks.clean_document_task.IndexProcessorFactory", autospec=True) as factory_cls:
        clean_processor = MagicMock()
        clean_processor.clean.return_value = None
        factory = MagicMock()
        factory.init_index_processor.return_value = clean_processor
        factory_cls.return_value = factory
        yield {
            "factory_cls": factory_cls,
            "factory_instance": factory,
            "processor": clean_processor,
        }
def _build_segment(segment_id: str, content: str = "segment content") -> MagicMock:
seg = MagicMock()
seg.id = segment_id
seg.index_node_id = f"node-{segment_id}"
seg.content = content
return seg
def _build_dataset(dataset_id: str, tenant_id: str) -> MagicMock:
ds = MagicMock()
ds.id = dataset_id
ds.tenant_id = tenant_id
return ds
class TestVectorCleanupResilience:
    """Vector / keyword cleanup must not abort the task on transient failure."""

    @staticmethod
    def _wrap_session(session: MagicMock) -> MagicMock:
        """Wrap *session* in a context-manager mock, mirroring ``create_session()``."""
        cm = MagicMock()
        cm.__enter__.return_value = session
        cm.__exit__.return_value = None
        return cm

    @classmethod
    def _empty_session_cm(cls, sessions: list[MagicMock]) -> MagicMock:
        """Fresh empty-result session cm, recording the session in *sessions*.

        Replaces the three copy-pasted ``_default_cm`` closures previously
        defined inline in each test; behavior is identical (empty
        scalars/execute/scalar results, session recorded so later tests can
        count ``execute`` calls across Step 3+ short transactions).
        """
        session = MagicMock()
        session.scalars.return_value.all.return_value = []
        session.execute.return_value.all.return_value = []
        session.scalar.return_value = None
        sessions.append(session)
        return cls._wrap_session(session)

    def test_billing_failure_during_vector_cleanup_does_not_skip_pg_cleanup(
        self,
        document_id,
        dataset_id,
        tenant_id,
        mock_session_factory,
        mock_storage,
        mock_index_processor_factory,
    ):
        """Reproduces the production incident:

        ``Vector(dataset)`` transitively calls ``FeatureService.get_features``
        which calls ``BillingService._send_request("GET", ...)``. When billing
        returns non-200 it raises ``ValueError("Unable to retrieve billing
        information...")``. Before the fix this propagated out of
        ``clean_document_task`` and left ``DocumentSegment`` / ``ChildChunk`` /
        ``UploadFile`` / ``DatasetMetadataBinding`` rows orphaned because the
        already-deleted ``Document`` row had been hard-committed by the caller
        (``dataset_service.delete_document``) before ``.delay()`` was invoked.

        Contract: a billing failure inside ``index_processor.clean()`` must be
        caught, logged, and the rest of the task must continue so PG ends up
        consistent with the deleted ``Document`` even if Qdrant retains
        orphan vectors that can be reaped later.
        """
        mock_sf, sessions = mock_session_factory
        # First create_session(): Step 1 (load segments + attachments).
        step1_session = MagicMock()
        step1_session.scalars.return_value.all.return_value = [
            _build_segment("seg-1"),
            _build_segment("seg-2"),
        ]
        step1_session.execute.return_value.all.return_value = []
        step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
        # Second create_session(): Step 2 (vector cleanup). Returns dataset.
        step2_session = MagicMock()
        step2_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
        step2_session.scalars.return_value.all.return_value = []
        step2_session.execute.return_value.all.return_value = []
        # Subsequent sessions: Step 3+ (image / segment / file / metadata
        # cleanup). Empty results are fine for these short transactions.
        mock_sf.create_session.side_effect = [
            self._wrap_session(step1_session),
            self._wrap_session(step2_session),
        ] + [self._empty_session_cm(sessions) for _ in range(10)]
        # Simulate the production failure: index_processor.clean() raises
        # ValueError mirroring BillingService._send_request on non-200.
        mock_index_processor_factory["processor"].clean.side_effect = ValueError(
            "Unable to retrieve billing information. Please try again later or contact support."
        )
        # Act — must not raise out of the task even though clean() raises.
        clean_document_task(
            document_id=document_id,
            dataset_id=dataset_id,
            doc_form="paragraph",
            file_id=None,
        )
        # Assert
        # 1. Vector cleanup was attempted.
        mock_index_processor_factory["processor"].clean.assert_called_once()
        # 2. Despite the failure the task continued: at least one DocumentSegment
        #    delete was issued. We use the count of session.execute calls across
        #    later short transactions as a proxy for "Step 3+ executed".
        execute_calls = sum(s.execute.call_count for s in sessions)
        assert execute_calls > 0, (
            "Step 3+ DB cleanup did not run after vector cleanup failure; "
            "this regression would re-introduce the orphan-segment bug."
        )

    def test_vector_cleanup_success_path_remains_unaffected(
        self,
        document_id,
        dataset_id,
        tenant_id,
        mock_session_factory,
        mock_storage,
        mock_index_processor_factory,
    ):
        """Backward-compat: the happy path must still call ``clean()`` exactly
        once with the expected arguments and complete without errors.
        """
        mock_sf, sessions = mock_session_factory
        step1_session = MagicMock()
        step1_session.scalars.return_value.all.return_value = [_build_segment("seg-1")]
        step1_session.execute.return_value.all.return_value = []
        step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
        step2_session = MagicMock()
        step2_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
        step2_session.scalars.return_value.all.return_value = []
        step2_session.execute.return_value.all.return_value = []
        mock_sf.create_session.side_effect = [
            self._wrap_session(step1_session),
            self._wrap_session(step2_session),
        ] + [self._empty_session_cm(sessions) for _ in range(10)]
        clean_document_task(
            document_id=document_id,
            dataset_id=dataset_id,
            doc_form="paragraph",
            file_id=None,
        )
        assert mock_index_processor_factory["processor"].clean.call_count == 1
        # Index cleanup invoked with the expected delete_summaries /
        # delete_child_chunks flags.
        _, kwargs = mock_index_processor_factory["processor"].clean.call_args
        assert kwargs.get("with_keywords") is True
        assert kwargs.get("delete_child_chunks") is True
        assert kwargs.get("delete_summaries") is True

    def test_no_segments_skips_vector_cleanup(
        self,
        document_id,
        dataset_id,
        tenant_id,
        mock_session_factory,
        mock_storage,
        mock_index_processor_factory,
    ):
        """When the document has no segments (e.g. indexing failed before
        producing any), vector cleanup must not be attempted and therefore
        the new try/except wrapper does not change behavior here.
        """
        mock_sf, sessions = mock_session_factory
        step1_session = MagicMock()
        step1_session.scalars.return_value.all.return_value = []  # no segments
        step1_session.execute.return_value.all.return_value = []
        step1_session.scalar.return_value = _build_dataset(dataset_id, tenant_id)
        mock_sf.create_session.side_effect = [self._wrap_session(step1_session)] + [
            self._empty_session_cm(sessions) for _ in range(10)
        ]
        clean_document_task(
            document_id=document_id,
            dataset_id=dataset_id,
            doc_form="paragraph",
            file_id=None,
        )
        # Vector cleanup is gated on ``index_node_ids``; when there are no
        # segments the IndexProcessorFactory path is never entered.
        mock_index_processor_factory["factory_cls"].assert_not_called()

View File

@ -0,0 +1,289 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
from core.plugin.entities.plugin import PluginInstallationSource
from models.account import TenantPluginAutoUpgradeStrategy
MODULE = "tasks.process_tenant_plugin_autoupgrade_check_task"
def _make_plugin(plugin_id: str, version: str, source=PluginInstallationSource.Marketplace):
    """Build a minimal stand-in for a PluginInstallation entry returned by manager.list_plugins."""
    unique_identifier = f"{plugin_id}:{version}@deadbeef"
    return SimpleNamespace(
        plugin_id=plugin_id,
        plugin_unique_identifier=unique_identifier,
        version=version,
        source=source,
    )
def _make_manifest(plugin_id: str, latest_version: str) -> MarketplacePluginSnapshot:
    """Marketplace manifest advertising *latest_version* for *plugin_id* (``org/name``)."""
    org, name = plugin_id.split("/", 1)
    package_identifier = f"{plugin_id}:{latest_version}@cafe1234"
    package_url = f"https://marketplace.example/{plugin_id}/{latest_version}.difypkg"
    return MarketplacePluginSnapshot(
        org=org,
        name=name,
        latest_version=latest_version,
        latest_package_identifier=package_identifier,
        latest_package_url=package_url,
    )
def _run_task(
    *,
    plugins: list,
    manifests: list[MarketplacePluginSnapshot],
    strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
    upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
    exclude_plugins=None,
    include_plugins=None,
):
    """Execute the celery task synchronously with the plugin manager, the
    marketplace manifest fetch and ``PluginService.upgrade_plugin_with_marketplace``
    mocked out.

    Returns ``(upgrade_mock, recorded_calls)`` so each test can assert on the
    upgrade invocations.
    """
    installer_mock = MagicMock()
    installer_mock.list_plugins.return_value = plugins

    recorded_calls: list[tuple[str, str, str]] = []

    def _capture(tenant_id, original, new):
        recorded_calls.append((tenant_id, original, new))

    with (
        patch(f"{MODULE}.PluginInstaller", return_value=installer_mock),
        patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=manifests),
        patch(
            f"{MODULE}.PluginService.upgrade_plugin_with_marketplace",
            side_effect=_capture,
        ) as upgrade_mock,
    ):
        # Import inside the patch context so the task resolves the mocks.
        from tasks.process_tenant_plugin_autoupgrade_check_task import (
            process_tenant_plugin_autoupgrade_check_task,
        )

        process_tenant_plugin_autoupgrade_check_task(
            "tenant-1",
            strategy_setting,
            0,
            upgrade_mode,
            exclude_plugins or [],
            include_plugins or [],
        )
    return upgrade_mock, recorded_calls
class TestUpgradeCallsMarketplaceService:
    """
    Regression tests for the bug where the auto-upgrade task called
    manager.upgrade_plugin directly, which skipped downloading the new package
    from marketplace and uploading it to the daemon. The daemon then failed with
    "package file not found" and the upgrade silently never completed.
    """

    def test_upgrade_routes_through_plugin_service(self):
        installed = _make_plugin("acme/foo", "1.0.0")
        available = _make_manifest("acme/foo", "1.0.1")
        upgrade_mock, calls = _run_task(plugins=[installed], manifests=[available])
        upgrade_mock.assert_called_once()
        expected = ("tenant-1", installed.plugin_unique_identifier, available.latest_package_identifier)
        assert calls == [expected]

    def test_does_not_call_manager_upgrade_plugin_directly(self):
        """Locks in that we never go back to the broken path that bypassed download/upload."""
        installed = _make_plugin("acme/foo", "1.0.0")
        available = _make_manifest("acme/foo", "1.0.1")
        installer_mock = MagicMock()
        installer_mock.list_plugins.return_value = [installed]
        with (
            patch(f"{MODULE}.PluginInstaller", return_value=installer_mock),
            patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=[available]),
            patch(f"{MODULE}.PluginService.upgrade_plugin_with_marketplace"),
        ):
            from tasks.process_tenant_plugin_autoupgrade_check_task import (
                process_tenant_plugin_autoupgrade_check_task,
            )

            process_tenant_plugin_autoupgrade_check_task(
                "tenant-1",
                TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
                0,
                TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
                [],
                [],
            )
        installer_mock.upgrade_plugin.assert_not_called()
class TestStrategySetting:
    """The strategy setting gates whether and which version bumps are applied."""

    @staticmethod
    def _run(installed: str, latest: str, setting):
        # One plugin, one manifest — the common shape of every test below.
        return _run_task(
            plugins=[_make_plugin("acme/foo", installed)],
            manifests=[_make_manifest("acme/foo", latest)],
            strategy_setting=setting,
        )

    def test_disabled_strategy_skips_everything(self):
        upgrade_mock, _ = self._run(
            "1.0.0", "1.0.1", TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED
        )
        upgrade_mock.assert_not_called()

    def test_fix_only_upgrades_patch_version(self):
        upgrade_mock, calls = self._run(
            "1.0.0", "1.0.5", TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
        )
        upgrade_mock.assert_called_once()
        assert calls[0][2].endswith(":1.0.5@cafe1234")

    def test_fix_only_skips_minor_bump(self):
        upgrade_mock, _ = self._run(
            "1.0.0", "1.1.0", TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
        )
        upgrade_mock.assert_not_called()

    def test_fix_only_skips_major_bump(self):
        upgrade_mock, _ = self._run(
            "1.0.0", "2.0.0", TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
        )
        upgrade_mock.assert_not_called()

    def test_latest_strategy_skips_when_versions_equal(self):
        upgrade_mock, _ = self._run(
            "1.0.0", "1.0.0", TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
        )
        upgrade_mock.assert_not_called()
class TestUpgradeMode:
    """The upgrade mode controls which installed plugins are eligible."""

    @staticmethod
    def _two_plugin_setup(second_source=PluginInstallationSource.Marketplace):
        # Two installed plugins with newer marketplace versions available.
        plugins = [
            _make_plugin("acme/foo", "1.0.0"),
            _make_plugin("acme/bar", "2.0.0", source=second_source),
        ]
        manifests = [
            _make_manifest("acme/foo", "1.0.1"),
            _make_manifest("acme/bar", "2.0.1"),
        ]
        return plugins, manifests

    def test_mode_all_upgrades_every_marketplace_plugin(self):
        plugins, manifests = self._two_plugin_setup()
        upgrade_mock, calls = _run_task(
            plugins=plugins,
            manifests=manifests,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
        )
        assert upgrade_mock.call_count == 2
        upgraded_ids = sorted(c[1] for c in calls)
        assert upgraded_ids == sorted(p.plugin_unique_identifier for p in plugins)

    def test_mode_all_skips_non_marketplace_sources(self):
        plugins, manifests = self._two_plugin_setup(second_source=PluginInstallationSource.Github)
        upgrade_mock, calls = _run_task(
            plugins=plugins,
            manifests=manifests,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
        )
        assert upgrade_mock.call_count == 1
        assert calls[0][1] == plugins[0].plugin_unique_identifier

    def test_mode_partial_only_upgrades_included_plugins(self):
        plugins, manifests = self._two_plugin_setup()
        upgrade_mock, calls = _run_task(
            plugins=plugins,
            manifests=manifests,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
            include_plugins=["acme/foo"],
        )
        assert upgrade_mock.call_count == 1
        assert calls[0][1] == plugins[0].plugin_unique_identifier

    def test_mode_exclude_skips_excluded_plugins(self):
        plugins, manifests = self._two_plugin_setup()
        upgrade_mock, calls = _run_task(
            plugins=plugins,
            manifests=manifests,
            upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
            exclude_plugins=["acme/bar"],
        )
        assert upgrade_mock.call_count == 1
        assert calls[0][1] == plugins[0].plugin_unique_identifier
class TestErrorIsolation:
    """A failing upgrade for one plugin must not prevent the others' upgrades."""

    def test_one_plugin_failure_does_not_block_others(self):
        plugins = [
            _make_plugin("acme/foo", "1.0.0"),
            _make_plugin("acme/bar", "2.0.0"),
        ]
        manifests = [
            _make_manifest("acme/foo", "1.0.1"),
            _make_manifest("acme/bar", "2.0.1"),
        ]
        installer_mock = MagicMock()
        installer_mock.list_plugins.return_value = plugins

        attempted: list[str] = []

        def _explode_on_foo(tenant_id, original, new):
            attempted.append(original)
            if "foo" in original:
                raise RuntimeError("boom")

        with (
            patch(f"{MODULE}.PluginInstaller", return_value=installer_mock),
            patch(f"{MODULE}.marketplace_batch_fetch_plugin_manifests", return_value=manifests),
            patch(f"{MODULE}.PluginService.upgrade_plugin_with_marketplace", side_effect=_explode_on_foo),
        ):
            from tasks.process_tenant_plugin_autoupgrade_check_task import (
                process_tenant_plugin_autoupgrade_check_task,
            )

            process_tenant_plugin_autoupgrade_check_task(
                "tenant-1",
                TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
                0,
                TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
                [],
                [],
            )
        # Both plugins were attempted even though "foo" raised.
        assert any("foo" in original for original in attempted)
        assert any("bar" in original for original in attempted)

View File

@ -35,8 +35,8 @@
"stepOne.uploader.cancel": "Cancel",
"stepOne.uploader.change": "Change",
"stepOne.uploader.failed": "Upload failed",
"stepOne.uploader.tip": "Supports {{supportTypes}}. Max {{batchCount}} in a batch and {{size}} MB each.",
"stepOne.uploader.tipWithTotalLimit": "Supports {{supportTypes}}. Max {{batchCount}} in a batch and {{size}} MB each. Max total {{totalCount}} files.",
"stepOne.uploader.tip": "Ondersteunt {{supportTypes}}. Maximaal {{batchCount}} per batch en {{size}} MB per bestand.",
"stepOne.uploader.tipWithTotalLimit": "Ondersteunt {{supportTypes}}. Maximaal {{batchCount}} per batch en {{size}} MB per bestand. Maximaal {{totalCount}} bestanden in totaal.",
"stepOne.uploader.title": "Upload file",
"stepOne.uploader.validation.count": "Multiple files not supported",
"stepOne.uploader.validation.filesNumber": "You have reached the batch upload limit of {{filesNumber}}.",