merge evaluation-fe

This commit is contained in:
JzoNg 2026-04-27 14:36:32 +08:00
commit 47050b8d15
196 changed files with 3144 additions and 6784 deletions

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
api-unit:
name: API Unit Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
COVERAGE_FILE: coverage-unit
defaults:
@ -62,7 +62,7 @@ jobs:
api-integration:
name: API Integration Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
COVERAGE_FILE: coverage-integration
STORAGE_TYPE: opendal
@ -137,7 +137,7 @@ jobs:
api-coverage:
name: API Coverage
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
needs:
- api-unit
- api-integration

View File

@ -13,7 +13,7 @@ permissions:
jobs:
autofix:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Complete merge group check
if: github.event_name == 'merge_group'

View File

@ -26,6 +26,9 @@ jobs:
build:
runs-on: ${{ matrix.runs_on }}
if: github.repository == 'langgenius/dify'
permissions:
contents: read
id-token: write
strategy:
matrix:
include:
@ -35,28 +38,28 @@ jobs:
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
- service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-amd64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-arm64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
steps:
- name: Prepare
@ -70,8 +73,8 @@ jobs:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Extract metadata for Docker
id: meta
@ -81,16 +84,15 @@ jobs:
- name: Build Docker image
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
uses: depot/build-push-action@v1
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=${{ matrix.service_name }}
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
- name: Export digest
env:
@ -108,9 +110,33 @@ jobs:
if-no-files-found: error
retention-days: 1
fork-build-validate:
if: github.repository != 'langgenius/dify'
runs-on: ubuntu-24.04
strategy:
matrix:
include:
- service_name: "validate-api-amd64"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "validate-web-amd64"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
- name: Validate Docker image
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
with:
push: false
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: linux/amd64
create-manifest:
needs: build
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: github.repository == 'langgenius/dify'
strategy:
matrix:

View File

@ -9,7 +9,7 @@ concurrency:
jobs:
db-migration-test-postgres:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code
@ -59,7 +59,7 @@ jobs:
run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code

View File

@ -13,7 +13,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'

View File

@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/dev'

View File

@ -13,7 +13,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/enterprise'

View File

@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl'

View File

@ -14,40 +14,69 @@ concurrency:
jobs:
build-docker:
if: github.event.pull_request.head.repo.full_name == github.repository
runs-on: ${{ matrix.runs_on }}
permissions:
contents: read
id-token: write
strategy:
matrix:
include:
- service_name: "api-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "api-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "web/Dockerfile"
- service_name: "web-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Build Docker Image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
uses: depot/build-push-action@v1
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-docker-fork:
if: github.event.pull_request.head.repo.full_name != github.repository
runs-on: ubuntu-24.04
permissions:
contents: read
strategy:
matrix:
include:
- service_name: "api-amd64"
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
- name: Build Docker Image
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
with:
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: linux/amd64

View File

@ -7,7 +7,7 @@ jobs:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
with:

View File

@ -23,7 +23,7 @@ concurrency:
jobs:
pre_job:
name: Skip Duplicate Checks
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
outputs:
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
steps:
@ -39,7 +39,7 @@ jobs:
name: Check Changed Files
needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
outputs:
api-changed: ${{ steps.changes.outputs.api }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
@ -141,7 +141,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped API tests
run: echo "No API-related changes detected; skipping API tests."
@ -154,7 +154,7 @@ jobs:
- check-changes
- api-tests-run
- api-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize API Tests status
env:
@ -201,7 +201,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped web tests
run: echo "No web-related changes detected; skipping web tests."
@ -214,7 +214,7 @@ jobs:
- check-changes
- web-tests-run
- web-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize Web Tests status
env:
@ -260,7 +260,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped web full-stack e2e
run: echo "No E2E-related changes detected; skipping web full-stack E2E."
@ -273,7 +273,7 @@ jobs:
- check-changes
- web-e2e-run
- web-e2e-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize Web Full-Stack E2E status
env:
@ -325,7 +325,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped VDB tests
run: echo "No VDB-related changes detected; skipping VDB tests."
@ -338,7 +338,7 @@ jobs:
- check-changes
- vdb-tests-run
- vdb-tests-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize VDB Tests status
env:
@ -384,7 +384,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped DB migration tests
run: echo "No migration-related changes detected; skipping DB migration tests."
@ -397,7 +397,7 @@ jobs:
- check-changes
- db-migration-test-run
- db-migration-test-skip
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize DB Migration Test status
env:

View File

@ -12,7 +12,7 @@ permissions: {}
jobs:
comment:
name: Comment PR with pyrefly diff
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
actions: read
contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs:
pyrefly-diff:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
contents: read
issues: write

View File

@ -12,7 +12,7 @@ permissions: {}
jobs:
comment:
name: Comment PR with type coverage
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
actions: read
contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs:
pyrefly-type-coverage:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
contents: read
issues: write

View File

@ -16,7 +16,7 @@ jobs:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Complete merge group check
if: github.event_name == 'merge_group'

View File

@ -12,7 +12,7 @@ on:
jobs:
stale:
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
permissions:
issues: write
pull-requests: write

View File

@ -15,7 +15,7 @@ permissions:
jobs:
python-style:
name: Python Style
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code
@ -57,7 +57,7 @@ jobs:
web-style:
name: Web Style
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
defaults:
run:
working-directory: ./web
@ -108,6 +108,8 @@ jobs:
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web type check
@ -129,7 +131,7 @@ jobs:
superlinter:
name: SuperLinter
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
steps:
- name: Checkout code

View File

@ -18,7 +18,7 @@ concurrency:
jobs:
build:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
defaults:
run:

View File

@ -35,7 +35,7 @@ concurrency:
jobs:
translate:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
timeout-minutes: 120
steps:
@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
timeout-minutes: 5
steps:

View File

@ -16,7 +16,7 @@ jobs:
test:
name: Full VDB Tests
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
strategy:
matrix:
python-version:

View File

@ -13,7 +13,7 @@ concurrency:
jobs:
test:
name: VDB Smoke Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
strategy:
matrix:
python-version:

View File

@ -13,7 +13,7 @@ concurrency:
jobs:
test:
name: Web Full-Stack E2E
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
defaults:
run:
shell: bash

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
VITEST_COVERAGE_SCOPE: app-components
strategy:
@ -54,7 +54,7 @@ jobs:
name: Merge Test Reports
if: ${{ !cancelled() }}
needs: [test]
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
@ -92,7 +92,7 @@ jobs:
dify-ui-test:
name: dify-ui Tests
runs-on: ubuntu-latest
runs-on: depot-ubuntu-24.04
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:

View File

@ -147,7 +147,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source
### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)

View File

@ -37,6 +37,11 @@ class TagBindingRemovePayload(BaseModel):
type: TagType = Field(description="Tag type")
class TagBindingItemDeletePayload(BaseModel):
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel):
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
keyword: str | None = Field(None, description="Search keyword")
@ -70,6 +75,7 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam,
TagResponse,
)
@ -152,41 +158,107 @@ class TagUpdateDeleteApi(Resource):
return "", 204
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
def _require_tag_binding_edit_permission() -> None:
"""
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=payload.tag_id,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
"""Canonical collection resource for tag binding creation."""
@console_ns.doc("create_tag_binding")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
return _create_tag_bindings()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
"""Canonical item resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding")
@console_ns.doc(params={"id": "Tag ID"})
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
_require_tag_binding_edit_permission()
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=str(id),
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200
return _remove_tag_binding()

View File

@ -527,6 +527,7 @@ class RetrievalService:
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
assert i.index_node_id
segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)

View File

@ -11,6 +11,7 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.enums import SegmentType
class DatasetDocumentStore:
@ -127,6 +128,7 @@ class DatasetDocumentStore:
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -137,7 +139,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=self._user_id,
)
db.session.add(child_segment)
@ -163,6 +165,7 @@ class DatasetDocumentStore:
)
# add new child chunks
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -173,7 +176,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=self._user_id,
)
db.session.add(child_segment)

View File

@ -1036,7 +1036,7 @@ class DocumentSegment(Base):
return attachment_list
class ChildChunk(Base):
class ChildChunk(TypeBase):
__tablename__ = "child_chunks"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@ -1046,29 +1046,42 @@ class ChildChunk(Base):
)
# initial fields
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
segment_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
content = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
DateTime,
nullable=False,
server_default=sa.func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
indexing_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
completed_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, insert_default=None, server_default=None, init=False
)
index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255),
nullable=False,
server_default=sa.text("'automatic'"),
default=SegmentType.AUTOMATIC,
)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False)
@property
def dataset(self):

View File

@ -1867,15 +1867,18 @@ class MessageAnnotation(TypeBase):
)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
StringUUID,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID)
question: Mapped[str] = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -225,8 +225,10 @@ class TestSpanBuilder:
span = builder.build_span(span_data)
assert isinstance(span, ReadableSpan)
assert span.name == "test-span"
assert span.context is not None
assert span.context.trace_id == 123
assert span.context.span_id == 456
assert span.parent is not None
assert span.parent.span_id == 789
assert span.resource == resource
assert span.attributes == {"attr1": "val1"}

View File

@ -64,12 +64,13 @@ class TestSpanData:
def test_span_data_missing_required_fields(self):
with pytest.raises(ValidationError):
SpanData(
trace_id=123,
# span_id missing
name="test_span",
start_time=1000,
end_time=2000,
SpanData.model_validate(
{
"trace_id": 123,
"name": "test_span",
"start_time": 1000,
"end_time": 2000,
}
)
def test_span_data_arbitrary_types_allowed(self):

View File

@ -2,12 +2,14 @@ from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@ -44,7 +46,7 @@ class RecordingTraceClient:
self.endpoint = endpoint
self.added_spans: list[object] = []
def add_span(self, span) -> None:
def add_span(self, span: object) -> None:
self.added_spans.append(span)
def api_check(self) -> bool:
@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags.SAMPLED,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(context)
def _make_trace_metadata(
trace_id: int = 1,
workflow_span_id: int = 2,
session_id: str = "s",
user_id: str = "u",
links: list[Link] | None = None,
) -> TraceMetadata:
return TraceMetadata(
trace_id=trace_id,
workflow_span_id=workflow_span_id,
session_id=session_id,
user_id=user_id,
links=[] if links is None else links,
)
def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient:
return cast(RecordingTraceClient, trace_instance.trace_client)
def _recorded_span_data(trace_instance: AliyunDataTrace) -> list[SpanData]:
return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans)
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
defaults = {
"workflow_id": "workflow-id",
@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT
trace_instance.workflow_trace(trace_info)
add_workflow_span.assert_called_once()
passed_trace_metadata = add_workflow_span.call_args.args[1]
passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1])
assert passed_trace_metadata.trace_id == 111
assert passed_trace_metadata.workflow_span_id == 222
assert passed_trace_metadata.session_id == "c"
assert passed_trace_metadata.user_id == "u"
assert passed_trace_metadata.links == []
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"]
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_message_trace_info(message_data=None)
trace_instance.message_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
)
trace_instance.message_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 2
message_span, llm_span = trace_instance.trace_client.added_spans
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span, llm_span = spans
assert message_span.name == "message"
assert message_span.trace_id == 10
@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_dataset_retrieval_trace_info(message_data=None)
trace_instance.dataset_retrieval_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "dataset_retrieval"
assert span.attributes[RETRIEVAL_QUERY] == "query"
assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
trace_info = _make_tool_trace_info(message_data=None)
trace_instance.tool_trace(trace_info)
assert trace_instance.trace_client.added_spans == []
assert _recording_trace_client(trace_instance).added_spans == []
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
@ -371,8 +399,9 @@ def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p
)
)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "my-tool"
assert span.status == status
assert span.attributes[TOOL_NAME] == "my-tool"
@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches(
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type(
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))
@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors(
):
node_execution = MagicMock(spec=WorkflowNodeExecution)
trace_info = _make_workflow_trace_info()
trace_metadata = MagicMock()
trace_metadata = _make_trace_metadata()
monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
node_execution.node_type = BuiltinNodeTypes.CODE
@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "title"
@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch:
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
trace_metadata = _make_trace_metadata(links=[_make_link()])
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "my-tool"
@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa
aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "retrieval"
@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p
monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
node_execution = MagicMock(spec=WorkflowNodeExecution)
node_execution.id = "node-id"
node_execution.title = "llm"
@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
status = Status(StatusCode.OK)
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
trace_metadata = _make_trace_metadata()
# CASE 1: With message_id
trace_info = _make_workflow_trace_info(
@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
)
trace_instance.add_workflow_span(trace_info, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 2
message_span = trace_instance.trace_client.added_spans[0]
workflow_span = trace_instance.trace_client.added_spans[1]
client = _recording_trace_client(trace_instance)
spans = _recorded_span_data(trace_instance)
assert len(spans) == 2
message_span = spans[0]
workflow_span = spans[1]
assert message_span.name == "message"
assert message_span.span_kind == SpanKind.SERVER
@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.
assert workflow_span.span_kind == SpanKind.INTERNAL
assert workflow_span.parent_span_id == 20
trace_instance.trace_client.added_spans.clear()
client.added_spans.clear()
# CASE 2: Without message_id
trace_info_no_msg = _make_workflow_trace_info(message_id=None)
trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "workflow"
assert span.span_kind == SpanKind.SERVER
assert span.parent_span_id is None
@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch:
trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
trace_instance.suggested_question_trace(trace_info)
assert len(trace_instance.trace_client.added_spans) == 1
span = trace_instance.trace_client.added_spans[0]
spans = _recorded_span_data(trace_instance)
assert len(spans) == 1
span = spans[0]
assert span.name == "suggested_question"
assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

View File

@ -1,4 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock
from dify_trace_aliyun.entities.semconv import (
@ -170,7 +172,7 @@ def test_create_common_span_attributes():
def test_format_retrieval_documents():
# Not a list
assert format_retrieval_documents("not a list") == []
assert format_retrieval_documents(cast(list[object], "not a list")) == []
# Valid list
docs = [
@ -211,7 +213,7 @@ def test_format_retrieval_documents():
def test_format_input_messages():
# Not a dict
assert format_input_messages(None) == serialize_json_data([])
assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No prompts
assert format_input_messages({}) == serialize_json_data([])
@ -244,7 +246,7 @@ def test_format_input_messages():
def test_format_output_messages():
# Not a dict
assert format_output_messages(None) == serialize_json_data([])
assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([])
# No text
assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

View File

@ -25,13 +25,13 @@ class TestAliyunConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
AliyunConfig()
AliyunConfig.model_validate({})
with pytest.raises(ValidationError):
AliyunConfig(license_key="test_license")
AliyunConfig.model_validate({"license_key": "test_license"})
with pytest.raises(ValidationError):
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"})
def test_app_name_validation_empty(self):
"""Test app_name validation with empty value"""

View File

@ -1,4 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
@ -129,7 +130,7 @@ def test_set_span_status():
return "SilentErrorRepr"
span.reset_mock()
set_span_status(span, SilentError())
set_span_status(span, cast(Exception | str | None, SilentError()))
assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"

View File

@ -28,13 +28,13 @@ class TestLangfuseConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangfuseConfig()
LangfuseConfig.model_validate({})
with pytest.raises(ValidationError):
LangfuseConfig(public_key="public")
LangfuseConfig.model_validate({"public_key": "public"})
with pytest.raises(ValidationError):
LangfuseConfig(secret_key="secret")
LangfuseConfig.model_validate({"secret_key": "secret"})
def test_host_validation_empty(self):
"""Test host validation with empty value"""

View File

@ -2,6 +2,7 @@
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_langfuse.config import LangfuseConfig
@ -134,4 +135,4 @@ class TestLangFuseDataTraceCompletionStartTime:
assert trace._get_completion_start_time(start_time, None) is None
assert trace._get_completion_start_time(start_time, -1) is None
assert trace._get_completion_start_time(start_time, "invalid") is None
assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None

View File

@ -21,13 +21,13 @@ class TestLangSmithConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
LangSmithConfig()
LangSmithConfig.model_validate({})
with pytest.raises(ValidationError):
LangSmithConfig(api_key="key")
LangSmithConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
LangSmithConfig(project="project")
LangSmithConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@ -599,7 +599,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_instance.message_trace(_make_message_trace_info())
mock_tracing["start"].assert_called_once()
@ -609,7 +608,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(error="something broke")
trace_instance.message_trace(trace_info)
@ -620,7 +618,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
monkeypatch.setenv("FILES_URL", "http://files.test")
file_data = SimpleNamespace(url="path/to/file.png")
@ -638,7 +635,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
trace_instance.message_trace(trace_info)
@ -651,7 +647,6 @@ class TestMessageTrace:
end_user = MagicMock()
end_user.session_id = "session-xyz"
mock_db.session.query.return_value.where.return_value.first.return_value = end_user
trace_info = _make_message_trace_info(
metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@ -664,7 +659,6 @@ class TestMessageTrace:
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"
mock_db.session.query.return_value.where.return_value.first.return_value = None
trace_info = _make_message_trace_info(
metadata={"from_account_id": "acc-1"},

View File

@ -12,6 +12,7 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import cast
from unittest.mock import MagicMock, patch
from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
@ -69,6 +70,14 @@ def _make_opik_trace_instance() -> OpikDataTrace:
return instance
def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_trace)
def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
return cast(MagicMock, instance.add_span)
# ---------------------------------------------------------------------------
# _seed_to_uuid4
# ---------------------------------------------------------------------------
@ -155,21 +164,21 @@ class TestWorkflowTraceWithoutMessageId:
def test_root_span_is_created(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
assert instance.add_span.called
assert _add_span_mock(instance).called
def test_root_span_id_matches_expected(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
expected = self._expected_root_span_id(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected
def test_root_span_has_no_parent(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["parent_span_id"] is None
def test_trace_name_is_workflow_trace(self):
@ -177,21 +186,21 @@ class TestWorkflowTraceWithoutMessageId:
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_name_is_workflow_trace(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
def test_root_span_has_workflow_tag(self):
trace_info = _make_workflow_trace_info(message_id=None)
instance = self._run(trace_info)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert "workflow" in root_span_kwargs["tags"]
def test_node_execution_spans_are_parented_to_root(self):
@ -214,8 +223,9 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
# call_args_list[0] = root span, [1] = node execution span
assert instance.add_span.call_count == 2
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
add_span = _add_span_mock(instance)
assert add_span.call_count == 2
node_span_kwargs = add_span.call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id
def test_node_span_not_parented_to_workflow_app_log_id(self):
@ -240,7 +250,7 @@ class TestWorkflowTraceWithoutMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] != old_parent_id
def test_root_span_id_differs_from_trace_id(self):
@ -283,7 +293,7 @@ class TestWorkflowTraceWithMessageId:
trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
instance = self._run(trace_info)
trace_kwargs = instance.add_trace.call_args_list[0][0][0]
trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
def test_root_span_uses_workflow_run_id_directly(self):
@ -292,7 +302,7 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info)
expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
root_span_kwargs = instance.add_span.call_args_list[0][0][0]
root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
assert root_span_kwargs["id"] == expected_root_span_id
def test_root_span_id_differs_from_no_message_id_case(self):
@ -326,5 +336,5 @@ class TestWorkflowTraceWithMessageId:
instance = self._run(trace_info, node_executions=[node_exec])
node_span_kwargs = instance.add_span.call_args_list[1][0][0]
node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
assert node_span_kwargs["parent_span_id"] == expected_root_span_id

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import sys
import types
from types import SimpleNamespace
from typing import Any, TypedDict, cast
from unittest.mock import MagicMock
import pytest
@ -12,7 +13,7 @@ from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
@ -80,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
self.kwargs = kwargs
class PatchedCoreComponents(TypedDict):
span_exporter: MagicMock
span_processor: MagicMock
tracer: MagicMock
span: MagicMock
tracer_provider: MagicMock
logger: MagicMock
trace_api: Any
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
"""Drop fake metric modules into sys.modules so the client imports resolve."""
@ -118,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
span_exporter = MagicMock(name="span_exporter")
monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
@ -168,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
}
def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
return SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
def _build_client() -> TencentTraceClient:
return TencentTraceClient(
service_name="service",
@ -208,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
def test_resolve_grpc_target_handles_errors() -> None:
assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
@pytest.mark.parametrize(
@ -248,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
client.record_trace_duration(0.5)
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.hist_llm_duration = MagicMock(name="hist_llm_duration")
client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@ -258,10 +278,11 @@ def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str,
logger.debug.assert_called()
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
ctx = _make_span_context(span_id=2)
span.get_span_context.return_value = ctx
data = SpanData(
trace_id=1,
@ -280,14 +301,15 @@ def test_create_and_export_span_sets_attributes(patch_core_components: dict[str,
span.add_event.assert_called_once()
span.set_status.assert_called_once()
span.end.assert_called_once_with(end_time=20)
assert client.span_contexts[2] == "ctx"
assert client.span_contexts[2] == ctx
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
client.span_contexts[10] = "existing"
existing_context = _make_span_context(span_id=10)
client.span_contexts[10] = existing_context
span = patch_core_components["span"]
span.get_span_context.return_value = "child"
span.get_span_context.return_value = _make_span_context(span_id=11)
data = SpanData(
trace_id=1,
@ -302,14 +324,14 @@ def test_create_and_export_span_uses_parent_context(patch_core_components: dict[
client._create_and_export_span(data)
trace_api = patch_core_components["trace_api"]
trace_api.NonRecordingSpan.assert_called_once_with("existing")
trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
trace_api.set_span_in_context.assert_called_once()
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
client.tracer.start_span.side_effect = RuntimeError("boom")
client._create_and_export_span(
@ -385,7 +407,7 @@ def test_get_project_url() -> None:
assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span_processor = patch_core_components["span_processor"]
tracer_provider = patch_core_components["tracer_provider"]
@ -401,10 +423,11 @@ def test_shutdown_flushes_all_components(patch_core_components: dict[str, object
metric_reader.shutdown.assert_called_once()
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
meter_provider = meter_provider_instances[-1]
meter_provider.shutdown.side_effect = RuntimeError("boom")
assert client.metric_reader is not None
client.metric_reader.shutdown.side_effect = RuntimeError("boom")
client.shutdown()
@ -433,7 +456,7 @@ def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: p
assert client.metric_reader is None
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
@ -454,10 +477,10 @@ def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_com
logger.exception.assert_called_once()
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
client = _build_client()
span = patch_core_components["span"]
span.get_span_context.return_value = "ctx"
span.get_span_context.return_value = _make_span_context(span_id=2)
data = SpanData.model_construct(
trace_id=1,
@ -485,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_llm_duration")
client.hist_llm_duration = hist_mock
client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["foo"], str)
assert attrs["bar"] == 2
@ -496,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
hist_mock = MagicMock(name="hist_trace_duration")
client.hist_trace_duration = hist_mock
client.record_trace_duration(1.0, {"meta": object(), "ok": True})
client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
_, attrs = hist_mock.record.call_args.args
assert isinstance(attrs["meta"], str)
assert attrs["ok"] is True
@ -512,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
],
)
def test_record_methods_handle_exceptions(
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
) -> None:
client = _build_client()
hist_mock = MagicMock(name=attr_name)
@ -527,35 +550,38 @@ def test_record_methods_handle_exceptions(
def test_metrics_initializes_grpc_metric_exporter() -> None:
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
assert isinstance(exporter, DummyGrpcMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert metric_reader.exporter.kwargs["insecure"] is False
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
assert exporter.kwargs["insecure"] is False
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
assert isinstance(exporter, DummyHttpMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
client = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
assert isinstance(exporter, DummyJsonMetricExporter)
assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in metric_reader.exporter.kwargs
assert exporter.kwargs["endpoint"] == client.endpoint
assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
assert "preferred_temporality" in exporter.kwargs
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@ -564,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
_ = _build_client()
metric_reader = metric_reader_instances[-1]
exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in metric_reader.exporter.kwargs
assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
assert "preferred_temporality" not in exporter.kwargs
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -31,13 +31,13 @@ class TestWeaveConfig:
def test_missing_required_fields(self):
"""Test that required fields are enforced"""
with pytest.raises(ValidationError):
WeaveConfig()
WeaveConfig.model_validate({})
with pytest.raises(ValidationError):
WeaveConfig(api_key="key")
WeaveConfig.model_validate({"api_key": "key"})
with pytest.raises(ValidationError):
WeaveConfig(project="project")
WeaveConfig.model_validate({"project": "project"})
def test_endpoint_validation_https_only(self):
"""Test endpoint validation only allows HTTPS"""

View File

@ -6,9 +6,10 @@ requires-python = "~=3.12.0"
dependencies = [
# Legacy: mature and widely deployed
"bleach>=6.3.0",
"boto3>=1.42.91",
"boto3>=1.42.96",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask>=3.1.3,<4.0.0",
"flask-cors>=6.0.2",
"gevent>=26.4.0",
"gevent-websocket>=0.10.1",
@ -16,7 +17,7 @@ dependencies = [
"google-api-python-client>=2.194.0",
"gunicorn>=25.3.0",
"psycogreen>=1.0.2",
"psycopg2-binary>=2.9.11",
"psycopg2-binary>=2.9.12",
"python-socketio>=5.13.0",
"redis[hiredis]>=7.4.0",
"sendgrid>=6.12.5",
@ -32,13 +33,13 @@ dependencies = [
"flask-restx>=1.3.2,<2.0.0",
"google-cloud-aiplatform>=1.148.1,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
"opentelemetry-distro>=0.62b0,<1.0.0",
"opentelemetry-distro>=0.62b1,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-flask>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-httpx>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-redis>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-sqlalchemy>=0.62b0,<1.0.0",
"opentelemetry-propagator-b3>=1.41.0,<2.0.0",
"opentelemetry-propagator-b3>=1.41.1,<2.0.0",
"readabilipy>=0.3.0,<1.0.0",
"resend>=2.27.0,<3.0.0",
@ -117,7 +118,7 @@ dev = [
"faker>=40.15.0",
"lxml-stubs>=0.5.1",
"basedpyright>=1.39.3",
"ruff>=0.15.11",
"ruff>=0.15.12",
"pytest>=9.0.3",
"pytest-benchmark>=5.2.3",
"pytest-cov>=7.1.0",
@ -144,7 +145,7 @@ dev = [
"types-pexpect>=4.9.0",
"types-protobuf>=7.34.1",
"types-psutil>=7.2.2",
"types-psycopg2>=2.9.21",
"types-psycopg2>=2.9.21.20260422",
"types-pygments>=2.20.0",
"types-pymysql>=1.1.0",
"types-python-dateutil>=2.9.0",
@ -157,9 +158,9 @@ dev = [
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.42.92",
"boto3-stubs>=1.42.96",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.152.1",
"hypothesis>=6.152.3",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408",
@ -169,12 +170,12 @@ dev = [
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
"mypy>=1.20.1",
"mypy>=1.20.2",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.62.0",
"xinference-client>=2.5.0",
"xinference-client>=2.7.0",
]
############################################################
@ -184,12 +185,12 @@ dev = [
storage = [
"azure-storage-blob>=12.28.0",
"bce-python-sdk>=0.9.70",
"cos-python-sdk-v5>=1.9.41",
"cos-python-sdk-v5>=1.9.42",
"esdk-obs-python>=3.22.2",
"google-cloud-storage>=3.10.1",
"opendal>=0.46.0",
"oss2>=2.19.1",
"supabase>=2.28.3",
"supabase>=2.29.0",
"tos>=2.9.0",
]
@ -272,7 +273,7 @@ vdb-vastbase = ["dify-vdb-vastbase"]
vdb-vikingdb = ["dify-vdb-vikingdb"]
vdb-weaviate = ["dify-vdb-weaviate"]
# Optional client used by some tests / integrations (not a vector backend plugin)
vdb-xinference = ["xinference-client>=2.5.0"]
vdb-xinference = ["xinference-client>=2.7.0"]
trace-all = [
"dify-trace-aliyun",

View File

@ -133,7 +133,14 @@ class AppAnnotationService:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=None,
message_id=None,
content=answer,
question=question,
account_id=current_user.id,
)
db.session.add(annotation)
db.session.commit()

View File

@ -89,7 +89,10 @@ class AsyncWorkflowService:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session)
# commit read only session before starting the billig rpc call
session.commit()
# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
@ -302,13 +305,21 @@ class AsyncWorkflowService:
return [log.to_dict() for log in logs]
@staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
def _get_workflow(
workflow_service: WorkflowService,
app_model: App,
workflow_id: str | None = None,
session: Session | None = None,
) -> Workflow:
"""
Get workflow for the app
Args:
app_model: App model instance
workflow_id: Optional specific workflow ID
session: Reuse this SQLAlchemy session for the lookup when provided,
so the caller's explicit session bears the connection cost
instead of Flask's request-scoped ``db.session``.
Returns:
Workflow instance
@ -318,12 +329,12 @@ class AsyncWorkflowService:
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session)
if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model)
workflow = workflow_service.get_published_workflow(app_model, session=session)
if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

View File

@ -3748,6 +3748,7 @@ class SegmentService:
ChildChunk.segment_id == segment.id,
)
)
assert current_user.current_tenant_id
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id,
@ -3758,7 +3759,7 @@ class SegmentService:
index_node_hash=index_node_hash,
content=content,
word_count=len(content),
type="customized",
type=SegmentType.CUSTOMIZED,
created_by=current_user.id,
)
db.session.add(child_chunk)
@ -3818,6 +3819,7 @@ class SegmentService:
if new_child_chunks_args:
child_chunk_count = len(child_chunks)
for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
assert current_user.current_tenant_id
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(args.content)
child_chunk = ChildChunk(
@ -3830,7 +3832,7 @@ class SegmentService:
index_node_hash=index_node_hash,
content=args.content,
word_count=len(args.content),
type="customized",
type=SegmentType.CUSTOMIZED,
created_by=current_user.id,
)

View File

@ -799,50 +799,47 @@ class WebhookService:
Exception: If workflow execution fails
"""
try:
with Session(db.engine) as session:
# Prepare inputs for the webhook node
# The webhook node expects webhook_data in the inputs
workflow_inputs = cls.build_workflow_inputs(webhook_data)
workflow_inputs = cls.build_workflow_inputs(webhook_data)
# Create trigger data
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id,
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
)
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
)
raise
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
# reserve quota before triggering workflow execution
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
logger.info(
"Tenant %s rate limited, skipping webhook trigger %s",
webhook_trigger.tenant_id,
webhook_trigger.webhook_id,
)
raise
# Trigger workflow execution asynchronously
try:
try:
# NOTE: don not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()`
# trigger_workflow_async need to handle multipe session commits internally
with Session(db.engine, expire_on_commit=False) as session:
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
quota_charge.commit()
except Exception:
quota_charge.refund()
raise
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

View File

@ -16,6 +16,7 @@ from extensions.ext_database import db
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.enums import SegmentType
logger = logging.getLogger(__name__)
@ -178,7 +179,7 @@ class VectorService:
index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content,
word_count=len(child_chunk.page_content),
type="automatic",
type=SegmentType.AUTOMATIC,
created_by=dataset_document.created_by,
)
db.session.add(child_segment)
@ -222,6 +223,7 @@ class VectorService:
)
documents.append(new_child_document)
for update_child_chunk in update_child_chunks:
assert update_child_chunk.index_node_id
child_document = Document(
page_content=update_child_chunk.content,
metadata={
@ -234,6 +236,7 @@ class VectorService:
documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
assert delete_child_chunk.index_node_id
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
@ -246,6 +249,7 @@ class VectorService:
@classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
assert child_chunk.index_node_id
vector.delete_by_ids([child_chunk.index_node_id])
@classmethod

View File

@ -173,11 +173,18 @@ class WorkflowService:
# return draft workflow
return workflow
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
def get_published_workflow_by_id(
self, app_model: App, workflow_id: str, session: Session | None = None
) -> Workflow | None:
"""
fetch published workflow by workflow_id
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
workflow = db.session.scalar(
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,
@ -195,16 +202,20 @@ class WorkflowService:
)
return workflow
def get_published_workflow(self, app_model: App) -> Workflow | None:
def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None:
"""
Get published workflow
When ``session`` is provided, reuse it so callers that already hold a
Session avoid checking out an extra request-scoped ``db.session``
connection. Falls back to ``db.session`` for backward compatibility.
"""
if not app_model.workflow_id:
return None
# fetch published workflow by workflow_id
workflow = db.session.scalar(
bind = session if session is not None else db.session
workflow = bind.scalar(
select(Workflow)
.where(
Workflow.tenant_id == app_model.tenant_id,

View File

@ -259,59 +259,58 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
# Ensure expire_on_commit is set to False to remain workflows available
with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
for plugin_trigger in subscribers:
# Get workflow from mapping
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
type=InvokeFrom.TRIGGER,
tenant_id=subscription.tenant_id,
app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers],
user_id=user_id,
)
# Find the trigger node in the workflow
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
# invoke trigger
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
for plugin_trigger in subscribers:
workflow: Workflow | None = workflows.get(plugin_trigger.app_id)
if not workflow:
logger.error(
"Workflow not found for app %s",
plugin_trigger.app_id,
)
continue
# reserve quota before invoking trigger
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info(
"Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id
)
return 0
event_node = None
for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE):
if node_id == plugin_trigger.node_id:
event_node = node_config
break
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
if not event_node:
logger.error("Trigger event node not found for app %s", plugin_trigger.app_id)
continue
trigger_metadata = PluginTriggerMetadata(
plugin_unique_identifier=provider_controller.plugin_unique_identifier or "",
endpoint_id=subscription.endpoint_id,
provider_id=subscription.provider_id,
event_name=event_name,
icon_filename=trigger_entity.identity.icon or "",
icon_dark_filename=trigger_entity.identity.icon_dark or "",
)
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id)
return dispatched_count
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node)
invoke_response: TriggerInvokeEventResponse | None = None
with session_factory.create_session() as session:
try:
invoke_response = TriggerManager.invoke_trigger_event(
tenant_id=subscription.tenant_id,
@ -403,7 +402,7 @@ def dispatch_triggered_workflow(
plugin_trigger.app_id,
)
return dispatched_count
return dispatched_count
def dispatch_triggered_workflows(

View File

@ -33,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
# Ensure expire_on_commit is set to False to remain schedule/tenant_owner available
with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
@ -42,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None:
if not tenant_owner:
raise TenantOwnerNotFoundError(f"No owner or admin found for tenant {schedule.tenant_id}")
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
quota_charge = unlimited()
try:
quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
except QuotaExceededError:
AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
return
try:
# Production dispatch: Trigger the workflow normally
try:
with session_factory.create_session() as session:
response = AsyncWorkflowService.trigger_workflow_async(
session=session,
user=tenant_owner,
@ -62,10 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None:
tenant_id=schedule.tenant_id,
),
)
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e
quota_charge.commit()
logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
except Exception as e:
quota_charge.refund()
raise ScheduleExecutionError(
f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}"
) from e

View File

@ -171,35 +171,13 @@ class TestChatMessageApiPermissions:
parent_message_id=None,
)
class MockQuery:
def __init__(self, model):
self.model = model
def where(self, *args, **kwargs):
return self
def first(self):
if getattr(self.model, "__name__", "") == "Conversation":
return mock_conversation
return None
def order_by(self, *args, **kwargs):
return self
def limit(self, *_):
return self
def all(self):
if getattr(self.model, "__name__", "") == "Message":
return [mock_message]
return []
mock_session = mock.Mock()
mock_session.query.side_effect = MockQuery
mock_session.scalar.return_value = False
mock_session.scalar.return_value = mock_conversation
mock_session.scalars.return_value.all.return_value = [mock_message]
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
monkeypatch.setattr(message_api, "current_user", mock_account)
monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock())
class DummyPagination:
def __init__(self, data, limit, has_more):

View File

@ -24,7 +24,6 @@ def _patch_wraps():
patch("controllers.console.wraps.dify_config", dify_settings),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield

View File

@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
def _execute_result(rows):
result = mock.Mock()
result.all.return_value = rows
return result
class TestFeedbackService:
"""Test FeedbackService methods."""
@ -81,25 +87,17 @@ class TestFeedbackService:
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in CSV format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)
# Test CSV export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@ -120,25 +118,17 @@ class TestFeedbackService:
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
"""Test exporting feedback data in JSON format."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)
# Test JSON export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@ -157,25 +147,17 @@ class TestFeedbackService:
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
"""Test exporting feedback with various filters."""
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
)
]
)
# Test with filters
result = FeedbackService.export_feedbacks(
@ -193,17 +175,7 @@ class TestFeedbackService:
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
"""Test exporting feedback when no data exists."""
# Setup mock query result with no data
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result([])
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@ -251,24 +223,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
long_message,
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
)
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@ -309,24 +274,17 @@ class TestFeedbackService:
created_at=datetime(2024, 1, 1, 10, 0, 0),
)
# Setup mock query result
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None, # No account for user feedback
)
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
chinese_feedback,
chinese_message,
sample_data["conversation"],
sample_data["app"],
None,
)
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@ -339,32 +297,24 @@ class TestFeedbackService:
def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
"""Test that rating emojis are properly formatted in export."""
# Setup mock query result with both like and dislike feedback
mock_query = mock.Mock()
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
mock_db_session.execute.return_value = mock_query
mock_db_session.execute.return_value = _execute_result(
[
(
sample_data["user_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["user_feedback"].from_account,
),
(
sample_data["admin_feedback"],
sample_data["message"],
sample_data["conversation"],
sample_data["app"],
sample_data["admin_feedback"].from_account,
),
]
)
# Test export
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")

View File

@ -10,6 +10,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from enums.quota_type import QuotaType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
@ -290,17 +291,26 @@ class TestWebhookServiceTriggerExecutionWithContainers:
end_user = SimpleNamespace(id=str(uuid4()))
webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}
quota_charge = MagicMock()
with (
patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
return_value=end_user,
),
patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume,
patch(
"services.trigger.webhook_service.QuotaService.reserve",
return_value=quota_charge,
) as mock_reserve,
patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger,
):
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
mock_consume.assert_called_once_with(webhook_trigger.tenant_id)
mock_reserve.assert_called_once()
reserve_args = mock_reserve.call_args.args
assert reserve_args[0] == QuotaType.TRIGGER
assert reserve_args[1] == webhook_trigger.tenant_id
quota_charge.commit.assert_called_once()
mock_trigger.assert_called_once()
trigger_args = mock_trigger.call_args.args
assert trigger_args[1] is end_user
@ -327,7 +337,7 @@ class TestWebhookServiceTriggerExecutionWithContainers:
return_value=SimpleNamespace(id=str(uuid4())),
),
patch(
"services.trigger.webhook_service.QuotaType.TRIGGER.consume",
"services.trigger.webhook_service.QuotaService.reserve",
side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1),
),
patch(

View File

@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine):
configure_session_factory(_unit_test_engine, expire_on_commit=False)
def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner):
"""
Helper to set up the mock DB execute chain for tenant/account authentication.
Helper to stub the tenant-owner execute result for service API app authentication.
This configures the mock to return (tenant, account) for the
db.session.execute(select(...).join().join().where()).one_or_none()
query used by validate_app_token decorator.
The validate_app_token decorator currently resolves the active tenant owner
via db.session.execute(select(Tenant, Account)...).one_or_none().
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_account: Mock account object to return
mock_owner: Mock owner object to return from the execute result
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner)
def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join):
"""
Helper to set up the mock DB execute chain for dataset tenant authentication.
Helper to stub the tenant-owner execute result for dataset token authentication.
This configures the mock to return (tenant, tenant_account) for the
db.session.execute(select(...).where().where().where().where()).one_or_none()
query used by validate_dataset_token decorator.
The validate_dataset_token decorator currently resolves the owner mapping via
db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and
then loads the Account separately via db.session.get(...).
Args:
mock_db: The mocked db object
mock_tenant: Mock tenant object to return
mock_ta: Mock tenant account object to return
mock_tenant_account_join: Mock tenant-account join object to return
"""
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join)

View File

@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@ -230,8 +228,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
@ -248,8 +244,6 @@ class TestAnnotationImportServiceValidation:
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with (
patch("services.annotation_service.current_account_with_tenant") as mock_auth,
patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
@ -269,8 +263,6 @@ class TestAnnotationImportServiceValidation:
file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")

View File

@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True
# Act
@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Act
with self.app.test_request_context(
@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False
# Act
@ -135,7 +132,6 @@ class TestAuthenticationSecurity:
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
"""Test that reset password returns success with token for existing accounts."""
# Mock the setup check
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Test with existing account
mock_get_user.return_value = MagicMock(email="existing@example.com")

View File

@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi:
- IP rate limiting is checked
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "email_token_123"
@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is allowed by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = True
@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi:
- Registration is blocked by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = False
@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi:
- Prevents spam and abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.side_effect = AccountRegisterError("Account frozen")
@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi:
- Defaults to en-US when not specified
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token"
@ -286,7 +280,6 @@ class TestEmailCodeLoginApi:
- User is logged in with token pair
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@ -335,7 +328,6 @@ class TestEmailCodeLoginApi:
- User is logged in after account creation
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
mock_get_user.return_value = None
mock_create_account.return_value = mock_account
@ -369,7 +361,6 @@ class TestEmailCodeLoginApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@ -392,7 +383,6 @@ class TestEmailCodeLoginApi:
- InvalidEmailError is raised when email doesn't match token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
# Act & Assert
@ -415,7 +405,6 @@ class TestEmailCodeLoginApi:
- EmailCodeError is raised for wrong verification code
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
# Act & Assert
@ -453,7 +442,6 @@ class TestEmailCodeLoginApi:
- User is added as owner of new workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@ -496,7 +484,6 @@ class TestEmailCodeLoginApi:
- WorkspacesLimitExceeded is raised when limit reached
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
@ -538,7 +525,6 @@ class TestEmailCodeLoginApi:
- NotAllowedCreateWorkspace is raised when creation disabled
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []

View File

@ -110,7 +110,6 @@ class TestLoginApi:
- Rate limit is reset after successful login
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@ -162,7 +161,6 @@ class TestLoginApi:
- Authentication proceeds with invitation token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
mock_authenticate.return_value = mock_account
@ -199,7 +197,6 @@ class TestLoginApi:
- EmailPasswordLoginLimitError is raised when limit exceeded
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
mock_get_invitation.return_value = None
@ -228,7 +225,6 @@ class TestLoginApi:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_frozen.return_value = True
# Act & Assert
@ -268,7 +264,6 @@ class TestLoginApi:
- Generic error message prevents user enumeration
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
@ -305,7 +300,6 @@ class TestLoginApi:
- Login is prevented even with valid credentials
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountLoginError("Account is banned")
@ -351,7 +345,6 @@ class TestLoginApi:
- User cannot login without an assigned workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
@ -383,7 +376,6 @@ class TestLoginApi:
- Security check prevents invitation token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
@ -425,7 +417,6 @@ class TestLoginApi:
mock_token_pair,
):
"""Test that login retries with lowercase email when uppercase lookup fails."""
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
@ -459,7 +450,6 @@ class TestLoginApi:
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")
@ -513,7 +503,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
# Act
@ -539,7 +528,6 @@ class TestLogoutApi:
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
# Create a mock anonymous user that will pass isinstance check
anonymous_user = MagicMock()
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})

View File

@ -46,7 +46,6 @@ class TestPartnerTenants:
patch("libs.login.dify_config.LOGIN_DISABLED", False),
patch("libs.login.check_csrf_token") as mock_csrf,
):
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_csrf.return_value = None
yield {"db": mock_db, "csrf": mock_csrf}

View File

@ -8,8 +8,10 @@ from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns
from controllers.console.tag.tags import (
TagBindingCreateApi,
TagBindingDeleteApi,
DeprecatedTagBindingCreateApi,
DeprecatedTagBindingRemoveApi,
TagBindingCollectionApi,
TagBindingItemApi,
TagListApi,
TagUpdateDeleteApi,
)
@ -205,9 +207,9 @@ class TestTagUpdateDeleteApi:
assert status == 204
class TestTagBindingCreateApi:
class TestTagBindingCollectionApi:
def test_create_success(self, app, admin_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)
payload = {
@ -232,7 +234,7 @@ class TestTagBindingCreateApi:
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingCreateApi()
api = TagBindingCollectionApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@ -247,9 +249,78 @@ class TestTagBindingCreateApi:
method(api)
class TestTagBindingDeleteApi:
class TestDeprecatedTagBindingCreateApi:
def test_create_success(self, app, admin_user, payload_patch):
api = DeprecatedTagBindingCreateApi()
method = unwrap(api.post)
payload = {
"tag_ids": ["tag-1"],
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
save_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
class TestTagBindingItemApi:
def test_delete_success(self, app, admin_user, payload_patch):
api = TagBindingItemApi()
method = unwrap(api.delete)
payload = {
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api, "tag-1")
delete_mock.assert_called_once()
delete_payload = delete_mock.call_args.args[0]
assert delete_payload.tag_id == "tag-1"
assert delete_payload.target_id == "target-1"
assert delete_payload.type == TagType.KNOWLEDGE
assert status == 200
assert result["result"] == "success"
def test_delete_forbidden(self, app, readonly_user):
api = TagBindingItemApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
class TestDeprecatedTagBindingRemoveApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
payload = {
@ -274,7 +345,7 @@ class TestTagBindingDeleteApi:
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingDeleteApi()
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@ -297,3 +368,35 @@ class TestTagResponseModel:
assert payload["type"] == "knowledge"
assert payload["binding_count"] == "1"
class TestTagBindingRouteMetadata:
def test_legacy_write_routes_are_marked_deprecated(self):
assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True
assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True
assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True
def test_write_routes_have_stable_operation_ids(self):
assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding"
assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding"
assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated"
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated"
def test_canonical_and_legacy_write_routes_are_registered(self):
route_map = {
resource.__name__: urls
for resource, urls, _route_doc, _kwargs in console_ns.resources
if resource.__name__
in {
"TagBindingCollectionApi",
"TagBindingItemApi",
"DeprecatedTagBindingCreateApi",
"DeprecatedTagBindingRemoveApi",
}
}
assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",)
assert route_map["TagBindingItemApi"] == ("/tag-bindings/<uuid:id>",)
assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",)
assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",)

View File

@ -24,10 +24,6 @@ def app():
return app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
account = Account(name=account_id, email=email)
@ -64,7 +60,6 @@ class TestChangeEmailSend:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@ -117,7 +112,6 @@ class TestChangeEmailSend:
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
@ -163,7 +157,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
@ -223,7 +216,6 @@ class TestChangeEmailValidity:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -280,7 +272,6 @@ class TestChangeEmailValidity:
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -330,7 +321,6 @@ class TestChangeEmailValidity:
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
mock_is_rate_limit.return_value = False
@ -378,7 +368,6 @@ class TestChangeEmailReset:
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -434,7 +423,6 @@ class TestChangeEmailReset:
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -488,7 +476,6 @@ class TestChangeEmailReset:
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
@ -561,7 +548,6 @@ class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
_mock_wraps_db(mock_db)
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@ -578,7 +564,6 @@ class TestCheckEmailUnique:
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
_mock_wraps_db(mock_db)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True

View File

@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask, g
@ -16,10 +16,6 @@ def app():
return flask_app
def _mock_wraps_db(mock_db):
mock_db.session.query.return_value.first.return_value = MagicMock()
def _build_feature_flags():
placeholder_quota = SimpleNamespace(limit=0, size=0)
workspace_members = SimpleNamespace(is_available=lambda count: True)
@ -49,7 +45,6 @@ class TestMemberInviteEmailApi:
mock_get_features,
app,
):
_mock_wraps_db(mock_db)
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"

View File

@ -310,7 +310,6 @@ class TestSystemSetup:
def test_should_allow_when_setup_complete(self, mock_db):
"""Test that requests are allowed when setup is complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
@setup_required
def admin_view():

View File

@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None
@contextmanager
def _mock_db():
mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True)
with patch("extensions.ext_database.db.session", mock_session):
yield

View File

@ -12,7 +12,7 @@ from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameter
from controllers.service_api.app.error import AppUnavailableError
from models.account import TenantStatus
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result
class TestAppParameterApi:
@ -74,7 +74,7 @@ class TestAppParameterApi:
# Mock tenant owner info for login
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -120,7 +120,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -161,7 +161,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -200,7 +200,7 @@ class TestAppParameterApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -263,7 +263,7 @@ class TestAppMetaApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -331,7 +331,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -388,7 +388,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -434,7 +434,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
@ -486,7 +486,7 @@ class TestAppInfoApi:
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):

View File

@ -15,7 +15,10 @@ from flask import Flask
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import TenantStatus
from models.model import App, AppMode, EndUser
from tests.unit_tests.conftest import setup_mock_tenant_account_query
from tests.unit_tests.conftest import (
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)
@pytest.fixture
@ -123,9 +126,7 @@ class AuthenticationMocker:
mock_db.session.get.side_effect = [mock_app, mock_tenant]
if mock_account:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
@staticmethod
def setup_dataset_auth(mock_db, mock_tenant, mock_account):
@ -133,8 +134,7 @@ class AuthenticationMocker:
mock_ta = Mock()
mock_ta.account_id = mock_account.id
mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
mock_db.session.get.return_value = mock_account

View File

@ -701,8 +701,8 @@ class TestDocumentApiDelete:
``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` which
internally calls ``validate_and_get_api_token``. To bypass the decorator
we call the original function via ``__wrapped__`` (preserved by
``functools.wraps``). ``delete`` queries the dataset via
``db.session.query(Dataset)`` directly, so we patch ``db`` at the
``functools.wraps``). ``delete`` loads the dataset via
``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the
controller module.
"""

View File

@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
from models.model import ApiToken
from tests.unit_tests.conftest import (
setup_mock_dataset_tenant_query,
setup_mock_tenant_account_query,
setup_mock_dataset_owner_execute_result,
setup_mock_tenant_owner_execute_result,
)
@ -141,14 +141,11 @@ class TestValidateAppToken:
mock_account = Mock()
mock_account.id = str(uuid.uuid4())
mock_ta = Mock()
mock_ta.account_id = mock_account.id
# Use side_effect to return app first, then tenant via session.get()
mock_db.session.get.side_effect = [mock_app, mock_tenant]
# Mock the tenant owner query (execute(select(...)).one_or_none())
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
# Mock the tenant owner execute result (execute(select(...)).one_or_none())
setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)
@validate_app_token
def protected_view(app_model):
@ -471,7 +468,7 @@ class TestValidateDatasetToken:
mock_account.current_tenant = mock_tenant
# Mock the tenant account join query (execute(select(...)).one_or_none())
setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta)
setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
# Mock the account lookup via session.get()
mock_db.session.get.return_value = mock_account

View File

@ -22,18 +22,16 @@ class FakeSession:
def __init__(self, mapping: dict[str, Any] | None = None):
self._mapping: dict[str, Any] = mapping or {}
self._model_name: str | None = None
def query(self, model: type) -> FakeSession:
self._model_name = model.__name__
return self
def get(self, model: type, _ident: object) -> Any:
return self._mapping.get(model.__name__)
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
return self
def first(self) -> Any:
assert self._model_name is not None
return self._mapping.get(self._model_name)
def scalar(self, stmt: Any) -> Any:
try:
model = stmt.column_descriptions[0]["entity"]
except (AttributeError, IndexError, KeyError, TypeError):
return None
return self._mapping.get(model.__name__)
class FakeDB:

View File

@ -36,18 +36,6 @@ class _FakeSession:
def __init__(self, mapping: dict[str, Any]):
self._mapping = mapping
self._model_name: str | None = None
def query(self, model):
self._model_name = model.__name__
return self
def where(self, *args, **kwargs):
return self
def first(self):
assert self._model_name is not None
return self._mapping.get(self._model_name)
def get(self, model, ident):
return self._mapping.get(model.__name__)

View File

@ -34,7 +34,6 @@ def _patch_wraps():
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
patch("controllers.web.login.dify_config", web_dify),
):
mock_db.session.query.return_value.first.return_value = MagicMock()
yield

View File

@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool
@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock ConversationVariable.from_variable to return mock objects
@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists
mock_db.engine = MagicMock()
# Mock GraphRuntimeState to accept the variable pool

View File

@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker):
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = workflow
session.get.return_value = workflow
mocker.patch.object(module.db, "session", session)
queue_manager = MagicMock()

View File

@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker):
app_generate_entity.single_iteration_run = None
app_generate_entity.single_loop_run = None
query = MagicMock()
query.where.return_value.first.return_value = None
session = MagicMock()
session.query.return_value = query
session.get.side_effect = [None, None]
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(
@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker):
app_generate_entity = _build_app_generate_entity()
pipeline = MagicMock(id="pipe")
query_pipeline = MagicMock()
query_pipeline.where.return_value.first.return_value = pipeline
session = MagicMock()
session.query.return_value = query_pipeline
session.get.side_effect = [None, pipeline]
mocker.patch.object(module.db, "session", session)
runner = PipelineRunner(

View File

@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime:
"last_edited_time": "2024-11-27T18:00:00.000Z",
}
mock_request.return_value = mock_response
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
extractor_page.update_last_edited_time(mock_document_model)
@ -863,9 +860,6 @@ class TestNotionExtractorIntegration:
}
mock_request.side_effect = [last_edited_response, block_response]
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
documents = extractor.extract()
@ -919,10 +913,6 @@ class TestNotionExtractorIntegration:
}
mock_post.return_value = database_response
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
# Act
documents = extractor.extract()

View File

@ -40,11 +40,11 @@ class TestObfuscatedToken:
class TestEncryptToken:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_successful_encryption(self, mock_encrypt, mock_query):
def test_successful_encryption(self, mock_encrypt, mock_get):
"""Test successful token encryption"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_data"
result = encrypt_token("tenant-123", "test_token")
@ -53,9 +53,9 @@ class TestEncryptToken:
mock_encrypt.assert_called_with("test_token", "mock_public_key")
@patch("extensions.ext_database.db.session.get")
def test_tenant_not_found(self, mock_query):
def test_tenant_not_found(self, mock_get):
"""Test error when tenant doesn't exist"""
mock_query.return_value = None
mock_get.return_value = None
with pytest.raises(ValueError) as exc_info:
encrypt_token("invalid-tenant", "test_token")
@ -122,12 +122,12 @@ class TestEncryptDecryptIntegration:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
@patch("libs.rsa.decrypt")
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_get):
"""Test that encryption and decryption are consistent"""
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
# Setup mock encryption/decryption
original_token = "test_token_123"
@ -148,12 +148,12 @@ class TestSecurity:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
def test_cross_tenant_isolation(self, mock_encrypt, mock_get):
"""Ensure tokens encrypted for one tenant cannot be used by another"""
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "tenant1_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_for_tenant1"
# Encrypt token for tenant1
@ -183,10 +183,10 @@ class TestSecurity:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_encryption_randomness(self, mock_encrypt, mock_query):
def test_encryption_randomness(self, mock_encrypt, mock_get):
"""Ensure same plaintext produces different ciphertext"""
mock_tenant = MagicMock(encrypt_public_key="key")
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
# Different outputs for same input
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
@ -207,11 +207,11 @@ class TestEdgeCases:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_get):
"""Test encryption of empty token"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_empty"
result = encrypt_token("tenant-123", "")
@ -221,11 +221,11 @@ class TestEdgeCases:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_get):
"""Test tokens containing special/unicode characters"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_special"
# Test various special characters
@ -244,11 +244,11 @@ class TestEdgeCases:
@patch("extensions.ext_database.db.session.get")
@patch("libs.rsa.encrypt")
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_get):
"""Test behavior when token exceeds RSA encryption limits"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_get.return_value = mock_tenant
# RSA 2048-bit can only encrypt ~245 bytes
# The actual limit depends on padding scheme

View File

@ -495,7 +495,7 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}
@ -521,7 +521,7 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
# Cause exception in node_type logic
workflow.graph_dict = {"graph": {"nodes": []}}
@ -548,7 +548,7 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
@ -636,7 +636,7 @@ class TestLLMGenerator:
instance.invoke_llm.return_value = mock_response
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}

View File

@ -29,15 +29,6 @@ class _Field:
return ("in", self._name, tuple(values))
class _FakeQuery:
def __init__(self):
self.where_calls: list[tuple] = []
def where(self, *conditions):
self.where_calls.append(conditions)
return self
class _FakeExecuteResult:
def __init__(self, segments: list[SimpleNamespace]):
self._segments = segments

View File

@ -109,17 +109,6 @@ class _FakeExecuteResult:
return _FakeExecuteScalarResult(self._data)
class _FakeSummaryQuery:
def __init__(self, summaries: list) -> None:
self._summaries = summaries
def filter(self, *args, **kwargs):
return self
def all(self) -> list:
return self._summaries
class _FakeScalarsResult:
def __init__(self, data: list) -> None:
self._data = data

View File

@ -372,19 +372,11 @@ def test_vector_delegation_methods(vector_factory_module):
def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch):
class _Field:
def __eq__(self, value):
return value
upload_query = MagicMock()
upload_query.where.return_value = upload_query
vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector)
vector._embeddings = MagicMock()
vector._vector_processor = MagicMock()
mock_session = SimpleNamespace(get=lambda _model, _id: None)
monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))
assert vector.search_by_file("file-1") == []

View File

@ -1484,11 +1484,8 @@ class TestIndexingRunnerProcessChunk:
mock_dependencies["redis"].get.return_value = None
# Mock database query for segment updates
mock_query = MagicMock()
mock_dependencies["db"].session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.update.return_value = None
# Mock database update for segment status
mock_dependencies["db"].session.execute.return_value = None
# Create a proper context manager mock
mock_context = MagicMock()

View File

@ -2417,12 +2417,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
mock_document.data_source_type = "upload_file"
mock_document.doc_metadata = {}
mock_session.query.return_value.filter.return_value.all.return_value = [
mock_dataset_from_db
]
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
[mock_dataset_from_db, mock_document]
)
mock_datasets = MagicMock()
mock_datasets.all.return_value = [mock_dataset_from_db]
mock_documents = MagicMock()
mock_documents.all.return_value = [mock_document]
mock_session.scalars.side_effect = [mock_datasets, mock_documents]
# Act
result = dataset_retrieval.knowledge_retrieval(request)

View File

@ -451,12 +451,11 @@ class TestDatasetRetrievalKnowledgeRetrieval:
mock_document.data_source_type = "upload_file"
mock_document.doc_metadata = {}
mock_session.query.return_value.filter.return_value.all.return_value = [
mock_dataset_from_db
]
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
[mock_dataset_from_db, mock_document]
)
mock_datasets = MagicMock()
mock_datasets.all.return_value = [mock_dataset_from_db]
mock_documents = MagicMock()
mock_documents.all.return_value = [mock_document]
mock_session.scalars.side_effect = [mock_datasets, mock_documents]
# Act
result = dataset_retrieval.knowledge_retrieval(request)

View File

@ -711,6 +711,8 @@ class TestMessageAnnotation:
annotation = MessageAnnotation(
app_id=app_id,
question="What is AI?",
conversation_id=None,
message_id=None,
content="AI stands for Artificial Intelligence.",
account_id=account_id,
)
@ -728,6 +730,8 @@ class TestMessageAnnotation:
annotation = MessageAnnotation(
app_id=str(uuid4()),
question="Test question",
conversation_id=None,
message_id=None,
content="Test content",
account_id=str(uuid4()),
)
@ -1068,6 +1072,8 @@ class TestModelIntegration:
app_id=app_id,
question="What is AI?",
content="AI stands for Artificial Intelligence.",
conversation_id=None,
message_id=message_id,
account_id=account_id,
)
annotation.id = annotation_id

View File

@ -365,7 +365,6 @@ def _make_segment(
def _make_child_chunk() -> ChildChunk:
return ChildChunk(
id="child-a",
tenant_id="tenant-1",
dataset_id="dataset-1",
document_id="doc-1",

File diff suppressed because it is too large Load Diff

View File

@ -1,925 +0,0 @@
"""
Extensive unit tests for ``ExternalDatasetService``.
This module focuses on the *external dataset service* surface area, which is responsible
for integrating with **external knowledge APIs** and wiring them into Dify datasets.
The goal of this test suite is twofold:
- Provide **high-confidence regression coverage** for all public helpers on
``ExternalDatasetService``.
- Serve as **executable documentation** for how external API integration is expected
to behave in different scenarios (happy paths, validation failures, and error codes).
The file intentionally contains **rich comments and generous spacing** in order to make
each scenario easy to scan during reviews.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, Mock, patch
import httpx
import pytest
from constants import HIDDEN_VALUE
from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
from services.entities.external_knowledge_entities.external_knowledge_entities import (
Authorization,
AuthorizationConfig,
ExternalKnowledgeApiSetting,
)
from services.errors.dataset import DatasetNameDuplicateError
from services.external_knowledge_service import ExternalDatasetService
class ExternalDatasetTestDataFactory:
    """
    Factory helpers for building *lightweight* mocks for external knowledge tests.

    These helpers are intentionally small and explicit:

    - They avoid pulling in unnecessary fixtures.
    - They reflect the minimal contract that the service under test cares about.
    """

    @staticmethod
    def create_external_api(
        api_id: str = "api-123",
        tenant_id: str = "tenant-1",
        name: str = "Test API",
        description: str = "Description",
        settings: dict[str, Any] | None = None,
    ) -> ExternalKnowledgeApis:
        """
        Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields.

        Using the real SQLAlchemy model (instead of a pure Mock) makes it easier to
        exercise ``settings_dict`` and other convenience properties if needed.

        :param settings: optional settings mapping; serialized to a JSON string,
            matching how the model stores its ``settings`` column.
        """
        import json  # local import keeps this fix self-contained

        instance = ExternalKnowledgeApis(
            tenant_id=tenant_id,
            name=name,
            description=description,
            # Bug fix: this previously assigned ``cast(str, pytest.approx)`` — a
            # function object, not a string — so any caller passing a ``settings``
            # dict got unusable data. Serialize the dict to JSON instead.
            settings=None if settings is None else json.dumps(settings),
        )
        # Overwrite generated id for determinism in assertions.
        instance.id = api_id
        return instance

    @staticmethod
    def create_dataset(
        dataset_id: str = "ds-1",
        tenant_id: str = "tenant-1",
        name: str = "External Dataset",
        provider: str = "external",
    ) -> Dataset:
        """
        Build a small ``Dataset`` instance representing an external dataset.
        """
        dataset = Dataset(
            tenant_id=tenant_id,
            name=name,
            description="",
            provider=provider,
            created_by="user-1",
        )
        # Deterministic id so tests can assert against it directly.
        dataset.id = dataset_id
        return dataset

    @staticmethod
    def create_external_binding(
        tenant_id: str = "tenant-1",
        dataset_id: str = "ds-1",
        api_id: str = "api-1",
        external_knowledge_id: str = "knowledge-1",
    ) -> ExternalKnowledgeBindings:
        """
        Small helper for a binding between dataset and external knowledge API.
        """
        binding = ExternalKnowledgeBindings(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            external_knowledge_api_id=api_id,
            external_knowledge_id=external_knowledge_id,
            created_by="user-1",
        )
        return binding
# ---------------------------------------------------------------------------
# get_external_knowledge_apis
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceGetExternalKnowledgeApis:
    """
    Tests for ``ExternalDatasetService.get_external_knowledge_apis``.

    Covered behaviour:

    - Basic pagination wiring via ``db.paginate``.
    - Optional search keyword behaviour.
    """

    @pytest.fixture
    def mock_db_paginate(self):
        """Patch ``db.paginate`` so we do not touch the real database layer."""
        with (
            patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate,
            patch("services.external_knowledge_service.select", autospec=True),
        ):
            yield mock_paginate

    def test_get_external_knowledge_apis_basic_pagination(self, mock_db_paginate: MagicMock):
        """It should return ``items`` and ``total`` coming from the paginate object."""
        # Arrange: paginate returns a lightweight object exposing items/total.
        expected_items = [Mock(spec=ExternalKnowledgeApis), Mock(spec=ExternalKnowledgeApis)]
        mock_db_paginate.return_value = SimpleNamespace(items=expected_items, total=42)

        # Act
        items, total = ExternalDatasetService.get_external_knowledge_apis(1, 20, "tenant-1")

        # Assert: result comes straight from the pagination object...
        assert items is expected_items
        assert total == 42
        # ...and paginate was invoked with the expected pagination kwargs.
        mock_db_paginate.assert_called_once()
        call_kwargs = mock_db_paginate.call_args.kwargs
        assert call_kwargs["page"] == 1
        assert call_kwargs["per_page"] == 20
        assert call_kwargs["max_per_page"] == 100
        assert call_kwargs["error_out"] is False

    def test_get_external_knowledge_apis_with_search_keyword(self, mock_db_paginate: MagicMock):
        """
        When a search keyword is provided, the query should be adjusted
        (we simply assert that paginate is still called and does not explode).
        """
        # Arrange: empty result page.
        mock_db_paginate.return_value = SimpleNamespace(items=[], total=0)

        # Act
        items, total = ExternalDatasetService.get_external_knowledge_apis(2, 10, "tenant-1", search="foo")

        # Assert
        assert items == []
        assert total == 0
        mock_db_paginate.assert_called_once()
# ---------------------------------------------------------------------------
# validate_api_list
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceValidateApiList:
    """
    Lightweight validation tests for ``validate_api_list``.
    """

    def test_validate_api_list_success(self):
        """A minimal valid configuration (endpoint + api_key) should pass."""
        valid_config = {"endpoint": "https://example.com", "api_key": "secret"}
        # Act & Assert — validation must complete without raising.
        ExternalDatasetService.validate_api_list(valid_config)

    @pytest.mark.parametrize(
        ("config", "expected_message"),
        [
            ({}, "api list is empty"),
            ({"api_key": "k"}, "endpoint is required"),
            ({"endpoint": "https://example.com"}, "api_key is required"),
        ],
    )
    def test_validate_api_list_failures(self, config: dict[str, Any], expected_message: str):
        """Invalid configs should raise ``ValueError`` with a clear message."""
        with pytest.raises(ValueError, match=expected_message):
            ExternalDatasetService.validate_api_list(config)
# ---------------------------------------------------------------------------
# create_external_knowledge_api & get/update/delete
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceCrudExternalKnowledgeApi:
    """
    CRUD tests for external knowledge API templates.

    Every test patches ``db.session`` so no real database is touched; the
    service's persistence calls (add/commit/delete) are asserted on the mock.
    """

    @pytest.fixture
    def mock_db_session(self):
        """
        Patch ``db.session`` for all CRUD tests in this class.
        """
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
        """
        ``create_external_knowledge_api`` should persist a new record
        when settings are present and valid.
        """
        tenant_id = "tenant-1"
        user_id = "user-1"
        args = {
            "name": "API",
            "description": "desc",
            "settings": {"endpoint": "https://api.example.com", "api_key": "secret"},
        }
        # We do not want to actually call the remote endpoint here, so we patch the validator.
        with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check:
            result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
        assert isinstance(result, ExternalKnowledgeApis)
        # The validator must receive the settings dict exactly as passed in.
        mock_check.assert_called_once_with(args["settings"])
        mock_db_session.add.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_create_external_knowledge_api_missing_settings_raises(self, mock_db_session: MagicMock):
        """
        Missing ``settings`` should result in a ``ValueError``.
        """
        tenant_id = "tenant-1"
        user_id = "user-1"
        args = {"name": "API", "description": "desc"}
        with pytest.raises(ValueError, match="settings is required"):
            ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
        # Nothing should be persisted on validation failure.
        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()

    def test_get_external_knowledge_api_found(self, mock_db_session: MagicMock):
        """
        ``get_external_knowledge_api`` should return the first matching record.
        """
        api = Mock(spec=ExternalKnowledgeApis)
        mock_db_session.scalar.return_value = api
        result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
        # The record is returned as-is, not copied.
        assert result is api

    def test_get_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        When the record is absent, a ``ValueError`` is raised.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")

    def test_update_external_knowledge_api_success_with_hidden_api_key(self, mock_db_session: MagicMock):
        """
        Updating an API should keep the existing API key when the special hidden
        value placeholder is sent from the UI.
        """
        tenant_id = "tenant-1"
        user_id = "user-1"
        api_id = "api-1"
        existing_api = Mock(spec=ExternalKnowledgeApis)
        # settings_dict is the parsed view the service reads the stored key from.
        existing_api.settings_dict = {"api_key": "stored-key"}
        existing_api.settings = '{"api_key":"stored-key"}'
        mock_db_session.scalar.return_value = existing_api
        args = {
            "name": "New Name",
            "description": "New Desc",
            # HIDDEN_VALUE is the UI placeholder meaning "keep the stored key".
            "settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
        }
        result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
        assert result is existing_api
        # The placeholder should be replaced with the stored key (args mutated in place).
        assert args["settings"]["api_key"] == "stored-key"
        mock_db_session.commit.assert_called_once()

    def test_update_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        Updating a non-existent API template should raise ``ValueError``.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.update_external_knowledge_api(
                tenant_id="tenant-1",
                user_id="user-1",
                external_knowledge_api_id="missing-id",
                args={"name": "n", "description": "d", "settings": {}},
            )

    def test_delete_external_knowledge_api_success(self, mock_db_session: MagicMock):
        """
        ``delete_external_knowledge_api`` should delete and commit when found.
        """
        api = Mock(spec=ExternalKnowledgeApis)
        mock_db_session.scalar.return_value = api
        ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
        # The exact record looked up must be the one deleted.
        mock_db_session.delete.assert_called_once_with(api)
        mock_db_session.commit.assert_called_once()

    def test_delete_external_knowledge_api_not_found_raises(self, mock_db_session: MagicMock):
        """
        Deletion of a missing template should raise ``ValueError``.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
# ---------------------------------------------------------------------------
# external_knowledge_api_use_check & binding lookups
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceUsageAndBindings:
    """
    Tests for usage checks and dataset binding retrieval.
    """

    @pytest.fixture
    def mock_db_session(self):
        # Keep the real database out of these tests.
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
        """
        When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
        """
        mock_db_session.scalar.return_value = 3

        is_in_use, binding_count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")

        assert is_in_use is True
        assert binding_count == 3
        # The generated statement should be scoped to the tenant.
        assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])

    def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
        """
        Zero bindings should return ``(False, 0)``.
        """
        mock_db_session.scalar.return_value = 0

        is_in_use, binding_count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")

        assert is_in_use is False
        assert binding_count == 0

    def test_get_external_knowledge_binding_with_dataset_id_found(self, mock_db_session: MagicMock):
        """
        Binding lookup should return the first record when present.
        """
        expected_binding = Mock(spec=ExternalKnowledgeBindings)
        mock_db_session.scalar.return_value = expected_binding

        found = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")

        assert found is expected_binding

    def test_get_external_knowledge_binding_with_dataset_id_not_found_raises(self, mock_db_session: MagicMock):
        """
        Missing binding should result in a ``ValueError``.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="external knowledge binding not found"):
            ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
# ---------------------------------------------------------------------------
# document_create_args_validate
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceDocumentCreateArgsValidate:
    """
    Tests for ``document_create_args_validate``.
    """

    @pytest.fixture
    def mock_db_session(self):
        # All lookups go through the patched session; no real database access.
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
        """
        All required custom parameters present — validation should pass.
        """
        json_settings = (
            '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
        )
        external_api = Mock(spec=ExternalKnowledgeApis)
        # Raw string; the service itself calls json.loads on it.
        external_api.settings = json_settings
        mock_db_session.scalar.return_value = external_api

        # Act & Assert — no exception expected.
        ExternalDatasetService.document_create_args_validate(
            "tenant-1", "api-1", {"foo": "value", "bar": "optional"}
        )

        assert json_settings in external_api.settings  # simple sanity check on our test data

    def test_document_create_args_validate_missing_template_raises(self, mock_db_session: MagicMock):
        """
        When the referenced API template is missing, a ``ValueError`` is raised.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})

    def test_document_create_args_validate_missing_required_parameter_raises(self, mock_db_session: MagicMock):
        """
        Required document process parameters must be supplied.
        """
        external_api = Mock(spec=ExternalKnowledgeApis)
        external_api.settings = (
            '[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
        )
        mock_db_session.scalar.return_value = external_api

        incomplete_parameters = {"bar": "present"}  # missing "foo"
        with pytest.raises(ValueError, match="foo is required"):
            ExternalDatasetService.document_create_args_validate("tenant-1", "api-1", incomplete_parameters)
# ---------------------------------------------------------------------------
# process_external_api
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceProcessExternalApi:
    """
    Tests focused on the HTTP request assembly and method mapping behaviour.
    """

    def test_process_external_api_valid_method_post(self):
        """
        For a supported HTTP verb we should delegate to the correct ``ssrf_proxy`` function.
        """
        api_setting = ExternalKnowledgeApiSetting(
            url="https://example.com/path",
            request_method="POST",
            headers={"X-Test": "1"},
            params={"foo": "bar"},
        )
        stub_response = httpx.Response(200)

        with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post:
            mock_post.return_value = stub_response
            response = ExternalDatasetService.process_external_api(api_setting, files=None)

        # The proxied response is passed straight through.
        assert response is stub_response
        mock_post.assert_called_once()
        call_kwargs = mock_post.call_args.kwargs
        assert call_kwargs["url"] == api_setting.url
        assert call_kwargs["headers"] == api_setting.headers
        assert call_kwargs["follow_redirects"] is True
        assert "data" in call_kwargs

    def test_process_external_api_invalid_method_raises(self):
        """
        An unsupported HTTP verb should raise ``InvalidHttpMethodError``.
        """
        bad_setting = ExternalKnowledgeApiSetting(
            url="https://example.com",
            request_method="INVALID",
            headers=None,
            params={},
        )
        # NOTE(review): the 'graphon.nodes...' import path looks unusual for this
        # codebase — confirm it resolves (exception historically lived under the
        # workflow http_request node's exc module).
        from graphon.nodes.http_request.exc import InvalidHttpMethodError

        with pytest.raises(InvalidHttpMethodError):
            ExternalDatasetService.process_external_api(bad_setting, files=None)
# ---------------------------------------------------------------------------
# assembling_headers
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceAssemblingHeaders:
    """
    Tests for header assembly based on different authentication flavours.
    """

    def test_assembling_headers_bearer_token(self):
        """
        For bearer auth we expect ``Authorization: Bearer <key>`` by default.
        """
        bearer_auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="bearer", api_key="secret", header=None),
        )

        assembled = ExternalDatasetService.assembling_headers(bearer_auth)

        assert assembled["Authorization"] == "Bearer secret"

    def test_assembling_headers_basic_token_with_custom_header(self):
        """
        For basic auth we honour the configured header name.
        """
        basic_auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="basic", api_key="abc123", header="X-Auth"),
        )

        assembled = ExternalDatasetService.assembling_headers(basic_auth, headers={"Existing": "1"})

        # Pre-existing headers survive and the auth header is added alongside.
        assert assembled["Existing"] == "1"
        assert assembled["X-Auth"] == "Basic abc123"

    def test_assembling_headers_custom_type(self):
        """
        Custom auth type should inject the raw API key.
        """
        custom_auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="custom", api_key="raw-key", header="X-API-KEY"),
        )

        assembled = ExternalDatasetService.assembling_headers(custom_auth, headers=None)

        assert assembled["X-API-KEY"] == "raw-key"

    def test_assembling_headers_missing_config_raises(self):
        """
        Missing config object should be rejected.
        """
        config_less_auth = Authorization(type="api-key", config=None)
        with pytest.raises(ValueError, match="authorization config is required"):
            ExternalDatasetService.assembling_headers(config_less_auth)

    def test_assembling_headers_missing_api_key_raises(self):
        """
        ``api_key`` is required when type is ``api-key``.
        """
        key_less_auth = Authorization(
            type="api-key",
            config=AuthorizationConfig(type="bearer", api_key=None, header="Authorization"),
        )
        with pytest.raises(ValueError, match="api_key is required"):
            ExternalDatasetService.assembling_headers(key_less_auth)

    def test_assembling_headers_no_auth_type_leaves_headers_unchanged(self):
        """
        For ``no-auth`` we should not modify the headers mapping.
        """
        no_auth = Authorization(type="no-auth", config=None)
        base_headers = {"X": "1"}

        assembled = ExternalDatasetService.assembling_headers(no_auth, headers=base_headers)

        # A copy is returned, original is not mutated.
        assert assembled == base_headers
        assert assembled is not base_headers
# ---------------------------------------------------------------------------
# get_external_knowledge_api_settings
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceGetExternalKnowledgeApiSettings:
    """
    Simple shape test for ``get_external_knowledge_api_settings``.
    """

    def test_get_external_knowledge_api_settings(self):
        raw_settings: dict[str, Any] = {
            "url": "https://example.com/retrieval",
            "request_method": "post",
            "headers": {"Content-Type": "application/json"},
            "params": {"foo": "bar"},
        }

        setting = ExternalDatasetService.get_external_knowledge_api_settings(raw_settings)

        # Every input field must be carried over onto the typed settings object.
        assert isinstance(setting, ExternalKnowledgeApiSetting)
        for field_name in ("url", "request_method", "headers", "params"):
            assert getattr(setting, field_name) == raw_settings[field_name]
# ---------------------------------------------------------------------------
# create_external_dataset
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceCreateExternalDataset:
    """
    Tests around creating the external dataset and its binding row.

    The ``db.session.scalar`` mock is fed ordered ``side_effect`` lists because the
    service performs sequential lookups (duplicate-name check, then API template).
    """

    @pytest.fixture
    def mock_db_session(self):
        # Patched session: persistence and lookups are asserted on the mock.
        with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
            yield mock_session

    def test_create_external_dataset_success(self, mock_db_session: MagicMock):
        """
        A brand new dataset name with valid external knowledge references
        should create both the dataset and its binding.
        """
        tenant_id = "tenant-1"
        user_id = "user-1"
        args = {
            "name": "My Dataset",
            "description": "desc",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
            "external_retrieval_model": {"top_k": 3},
        }
        # No existing dataset with same name.
        mock_db_session.scalar.side_effect = [
            None,  # duplicate-name check
            Mock(spec=ExternalKnowledgeApis),  # external knowledge api
        ]
        dataset = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
        assert isinstance(dataset, Dataset)
        assert dataset.provider == "external"
        assert dataset.retrieval_model == args["external_retrieval_model"]
        assert mock_db_session.add.call_count >= 2  # dataset + binding
        mock_db_session.flush.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_create_external_dataset_duplicate_name_raises(self, mock_db_session: MagicMock):
        """
        When a dataset with the same name already exists,
        ``DatasetNameDuplicateError`` is raised.
        """
        existing_dataset = Mock(spec=Dataset)
        mock_db_session.scalar.return_value = existing_dataset
        args = {
            "name": "Existing",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
        }
        with pytest.raises(DatasetNameDuplicateError):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)
        # Nothing may be persisted when the name check fails.
        mock_db_session.add.assert_not_called()
        mock_db_session.commit.assert_not_called()

    def test_create_external_dataset_missing_api_template_raises(self, mock_db_session: MagicMock):
        """
        If the referenced external knowledge API does not exist, a ``ValueError`` is raised.
        """
        # First call: duplicate-name check — not found.
        mock_db_session.scalar.side_effect = [
            None,
            None,  # external knowledge api lookup
        ]
        args = {
            "name": "Dataset",
            "external_knowledge_api_id": "missing",
            "external_knowledge_id": "knowledge-1",
        }
        with pytest.raises(ValueError, match="api template not found"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args)

    def test_create_external_dataset_missing_required_ids_raise(self, mock_db_session: MagicMock):
        """
        ``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
        """
        # duplicate-name check — two calls to create_external_dataset, each does 2 scalar calls
        mock_db_session.scalar.side_effect = [
            None,
            Mock(spec=ExternalKnowledgeApis),
            None,
            Mock(spec=ExternalKnowledgeApis),
        ]
        args_missing_knowledge_id = {
            "name": "Dataset",
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": None,
        }
        with pytest.raises(ValueError, match="external_knowledge_id is required"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_knowledge_id)
        args_missing_api_id = {
            "name": "Dataset",
            "external_knowledge_api_id": None,
            "external_knowledge_id": "k-1",
        }
        with pytest.raises(ValueError, match="external_knowledge_api_id is required"):
            ExternalDatasetService.create_external_dataset("tenant-1", "user-1", args_missing_api_id)
# ---------------------------------------------------------------------------
# fetch_external_knowledge_retrieval
# ---------------------------------------------------------------------------
class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
    """
    Tests for ``fetch_external_knowledge_retrieval``, the method that sends
    retrieval requests to an external knowledge service and normalises the
    payload it returns.
    """

    @pytest.fixture
    def mock_db_session(self):
        # Patch the SQLAlchemy session used by the service so no real
        # database is touched.
        with patch("services.external_knowledge_service.db.session", autospec=True) as session_mock:
            yield session_mock

    @staticmethod
    def _mock_api_template():
        # Minimal API template carrying the JSON settings the service parses.
        template = Mock(spec=ExternalKnowledgeApis)
        template.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
        return template

    def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
        """
        When the binding and API template resolve and the HTTP call returns
        200, the external service's records are passed straight through.
        """
        tenant_id = "tenant-1"
        dataset_id = "ds-1"
        query = "test query"
        retrieval_params = {"top_k": 3, "score_threshold_enabled": True, "score_threshold": 0.5}
        knowledge_binding = ExternalDatasetTestDataFactory.create_external_binding(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            api_id="api-1",
            external_knowledge_id="knowledge-1",
        )
        # scalar() resolves the binding first, then the API template.
        mock_db_session.scalar.side_effect = [knowledge_binding, self._mock_api_template()]

        expected_records = [{"content": "doc", "score": 0.9}]
        http_response = Mock(spec=httpx.Response)
        http_response.status_code = 200
        http_response.json.return_value = {"records": expected_records}
        metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})

        with patch.object(
            ExternalDatasetService, "process_external_api", return_value=http_response, autospec=True
        ) as process_mock:
            result = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                query=query,
                external_retrieval_parameters=retrieval_params,
                metadata_condition=metadata_condition,
            )
            assert result == expected_records
            process_mock.assert_called_once()
            # The first positional argument must be a fully-built API setting
            # pointing at the /retrieval endpoint.
            setting = process_mock.call_args.args[0]
            assert isinstance(setting, ExternalKnowledgeApiSetting)
            assert setting.url.endswith("/retrieval")

    def test_fetch_external_knowledge_retrieval_binding_not_found_raises(self, mock_db_session: MagicMock):
        """
        A missing external knowledge binding raises ``ValueError``.
        """
        mock_db_session.scalar.return_value = None
        with pytest.raises(ValueError, match="external knowledge binding not found"):
            ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="missing",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )

    def test_fetch_external_knowledge_retrieval_missing_api_template_raises(self, mock_db_session: MagicMock):
        """
        A binding whose API template is absent (or has no settings) raises
        ``ValueError``.
        """
        # Binding lookup succeeds; template lookup returns nothing.
        mock_db_session.scalar.side_effect = [
            ExternalDatasetTestDataFactory.create_external_binding(),
            None,
        ]
        with pytest.raises(ValueError, match="external api template not found"):
            ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="ds-1",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )

    def test_fetch_external_knowledge_retrieval_non_200_status_returns_empty_list(self, mock_db_session: MagicMock):
        """
        Any non-200 response from the external service is treated as an
        empty result set.
        """
        mock_db_session.scalar.side_effect = [
            ExternalDatasetTestDataFactory.create_external_binding(),
            self._mock_api_template(),
        ]
        error_response = Mock(spec=httpx.Response)
        error_response.status_code = 500
        error_response.json.return_value = {}
        with patch.object(ExternalDatasetService, "process_external_api", return_value=error_response, autospec=True):
            result = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id="tenant-1",
                dataset_id="ds-1",
                query="q",
                external_retrieval_parameters={},
                metadata_condition=None,
            )
            assert result == []

View File

@ -374,24 +374,14 @@ def test_publish_workflow_success(mocker, rag_pipeline_service) -> None:
mock_db = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db", mock_db)
mock_dataset_service_class = mocker.patch("services.dataset_service.DatasetService")
mock_dataset_service = mock_dataset_service_class.return_value
# 6. Mock session and its scalar/query methods
# 6. Mock session and dataset lookup
mock_session = mocker.Mock()
mock_session.scalar.return_value = draft_wf
# Mock dataset update query (needed even if service is mocked, as rag_pipeline fetches it first)
dataset = mocker.Mock()
dataset.retrieval_model_dict = {}
dataset_query = mocker.Mock()
dataset_query.where.return_value.first.return_value = dataset
# Mock node execution copy
node_exec_query = mocker.Mock()
node_exec_query.where.return_value.all.return_value = []
# Mocked session query side effects
mock_session.query.side_effect = [node_exec_query, dataset_query]
pipeline.retrieve_dataset.return_value = dataset
# 7. Run test
result = rag_pipeline_service.publish_workflow(session=mock_session, pipeline=pipeline, account=account)
@ -1524,7 +1514,6 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
)
document = SimpleNamespace(indexing_status="waiting", error=None)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@ -1595,7 +1584,6 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
@ -1604,7 +1592,6 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
with pytest.raises(ValueError, match="Pipeline not found"):
@ -1644,7 +1631,6 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@ -1871,7 +1857,6 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
with pytest.raises(ValueError, match="Workflow not found"):
@ -1910,7 +1895,6 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
@ -1923,7 +1907,6 @@ def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_
def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
@ -1940,7 +1923,6 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
workflow = SimpleNamespace(
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@ -2103,7 +2085,6 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]},
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
mocker.patch(
@ -2143,7 +2124,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
{"variable": "v3", "belong_to_node_id": "shared"},
],
)
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
mocker.patch(
@ -2161,7 +2141,6 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
result = rag_pipeline_service.get_pipeline("t1", "d1")

File diff suppressed because it is too large Load Diff

View File

@ -1,59 +0,0 @@
from unittest.mock import MagicMock
class ServiceDbTestHelper:
    """
    Utilities for wiring mock database sessions in service-layer tests.
    """

    @staticmethod
    def setup_db_query_filter_by_mock(mock_db, query_results):
        """
        Configure ``mock_db.session.query`` so lookups resolve from a mapping.

        Args:
            mock_db: Mock database session object.
            query_results: Maps ``(model_name, filter_key, filter_value)``
                tuples to the object ``.first()`` should return.
                Example: {('Account', 'email', 'test@example.com'): mock_account}

        ``order_by`` chains resolve against entries registered with the
        sentinel pair ``(model_name, 'order_by', 'first_available')``.
        """

        def fake_query(model):
            query_mock = MagicMock()

            def fake_filter_by(**filters):
                filtered = MagicMock()

                def resolve_first():
                    # First configured result whose model name and filter
                    # key/value match this call; None when nothing matches.
                    for (model_name, key, value), result in query_results.items():
                        if model.__name__ == model_name and key in filters and filters[key] == value:
                            return result
                    return None

                filtered.first.side_effect = resolve_first

                def fake_order_by(*args, **kwargs):
                    ordered = MagicMock()

                    def resolve_ordered_first():
                        # order_by() ignores its arguments and answers from
                        # the ('order_by', 'first_available') sentinel entry.
                        for (model_name, key, value), result in query_results.items():
                            if model.__name__ == model_name and key == "order_by" and value == "first_available":
                                return result
                        return None

                    ordered.first.side_effect = resolve_ordered_first
                    return ordered

                filtered.order_by.side_effect = fake_order_by
                return filtered

            query_mock.filter_by.side_effect = fake_filter_by
            return query_mock

        mock_db.session.query.side_effect = fake_query

View File

@ -14,7 +14,6 @@ from services.errors.account import (
AccountRegisterError,
CurrentPasswordIncorrectError,
)
from tests.unit_tests.services.services_test_help import ServiceDbTestHelper
class TestAccountAssociatedDataFactory:
@ -149,7 +148,6 @@ class TestAccountService:
# Setup basic session methods
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.query = MagicMock()
yield mock_db
@ -1572,15 +1570,9 @@ class TestRegisterService:
account_id="existing-user-456", email="existing@example.com", status="active"
)
# Mock database queries
query_results = {
(
"TenantAccountJoin",
"tenant_id",
"tenant-456",
): TestAccountAssociatedDataFactory.create_tenant_join_mock(),
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies[
"db"
].session.scalar.return_value = TestAccountAssociatedDataFactory.create_tenant_join_mock()
# Mock TenantService methods
with (

View File

@ -238,6 +238,8 @@ class TestAppAnnotationServiceUpInsert:
assert result == annotation_instance
mock_cls.assert_called_once_with(
app_id=app.id,
conversation_id=None,
message_id=None,
content="hello",
question="q1",
account_id=current_user.id,

View File

@ -163,7 +163,7 @@ class TestAsyncWorkflowService:
mocks["quota_service"].reserve.assert_called_once()
quota_charge_mock.commit.assert_called_once()
assert session.commit.call_count == 2
assert session.commit.call_count == 3
created_log = mocks["repo"].create.call_args[0][0]
assert created_log.status == WorkflowTriggerStatus.QUEUED
@ -266,7 +266,7 @@ class TestAsyncWorkflowService:
trigger_data=trigger_data,
)
assert session.commit.call_count == 2
assert session.commit.call_count == 3
updated_log = mocks["repo"].update.call_args[0][0]
assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED
assert "Quota limit reached" in updated_log.error
@ -469,7 +469,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
# Assert
assert result == workflow
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123")
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123", session=None)
workflow_service.get_published_workflow.assert_not_called()
def test_should_raise_when_specific_workflow_id_not_found(self):
@ -497,7 +497,7 @@ class TestAsyncWorkflowServiceGetWorkflow:
# Assert
assert result == workflow
workflow_service.get_published_workflow.assert_called_once_with(app_model)
workflow_service.get_published_workflow.assert_called_once_with(app_model, session=None)
workflow_service.get_published_workflow_by_id.assert_not_called()
def test_should_raise_when_default_published_workflow_not_found(self):

View File

@ -89,7 +89,6 @@ class TestSegmentServiceChildChunks:
document = _make_document()
segment = _make_segment()
existing_a = ChildChunk(
id="child-a",
tenant_id="tenant-1",
dataset_id="dataset-1",
document_id="doc-1",
@ -100,7 +99,6 @@ class TestSegmentServiceChildChunks:
created_by="user-1",
)
existing_b = ChildChunk(
id="child-b",
tenant_id="tenant-1",
dataset_id="dataset-1",
document_id="doc-1",
@ -110,7 +108,8 @@ class TestSegmentServiceChildChunks:
word_count=9,
created_by="user-1",
)
existing_a.id = "child-a"
existing_b.id = "child-b"
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.uuid.uuid4", return_value="node-new"),
@ -714,7 +713,6 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
):
segments_query = MagicMock()
# execute().all() for segments_info (multi-column)
execute_result = MagicMock()
execute_result.all.return_value = [

View File

@ -36,9 +36,7 @@ class TestDatasourceProviderService:
@pytest.fixture
def mock_db_session(self):
"""
Robust, chainable query mock.
q returns itself for .filter_by(), .order_by(), .where() so any
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
Mock session with scalar/scalars defaults for current SQLAlchemy access paths.
"""
with (
patch("services.datasource_provider_service.Session") as mock_cls,
@ -46,20 +44,6 @@ class TestDatasourceProviderService:
):
sess = MagicMock(spec=Session)
q = MagicMock()
sess.query.return_value = q
# Self-returning chain — any method called on q returns q
q.filter_by.return_value = q
q.order_by.return_value = q
q.where.return_value = q
# Default terminal values (tests override per-case)
q.first.return_value = None
q.all.return_value = []
q.count.return_value = 0
q.delete.return_value = 1
# Default values for select()-style calls (tests override per-case)
sess.scalar.return_value = None
sess.scalars.return_value.all.return_value = []

View File

@ -17,23 +17,6 @@ from services.trigger import webhook_service as service_module
from services.trigger.webhook_service import WebhookService
class _FakeQuery:
def __init__(self, result: Any) -> None:
self._result = result
def where(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def filter(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def order_by(self, *args: Any, **kwargs: Any) -> "_FakeQuery":
return self
def first(self) -> Any:
return self._result
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)

Some files were not shown because too many files have changed in this diff Show More