Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-03-20 10:46:45 +08:00
commit 7b85adf1cc
351 changed files with 11867 additions and 5256 deletions

View File

@ -4,10 +4,10 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
with:
node-version-file: "./web/.nvmrc"
node-version-file: web/.nvmrc
cache: true
cache-dependency-path: web/pnpm-lock.yaml
run-install: |
- cwd: ./web
args: ['--frozen-lockfile']
cwd: ./web

View File

@ -12,7 +12,7 @@ jobs:
anti-slop:
runs-on: ubuntu-latest
steps:
- uses: peakoss/anti-slop@v0
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
close-pr: false

View File

@ -2,6 +2,12 @@ name: Run Pytest
on:
workflow_call:
secrets:
CODECOV_TOKEN:
required: false
permissions:
contents: read
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}
@ -11,6 +17,8 @@ jobs:
test:
name: API Tests
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
@ -24,10 +32,11 @@ jobs:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -79,21 +88,12 @@ jobs:
api/tests/test_containers_integration_tests \
api/tests/unit_tests
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
files: ./coverage.xml
disable_search: true
flags: api
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

View File

@ -39,7 +39,7 @@ jobs:
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'

View File

@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"
@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -56,16 +56,14 @@ jobs:
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
secrets: inherit
web-tests:
name: Web Tests
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
with:
base_sha: ${{ github.event.before || github.event.pull_request.base.sha }}
diff_range_mode: ${{ github.event.before && 'exact' || 'merge-base' }}
head_sha: ${{ github.event.after || github.event.pull_request.head.sha || github.sha }}
secrets: inherit
style-check:
name: Style Check

View File

@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true

View File

@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: false
python-version: "3.12"

View File

@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -31,7 +31,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -2,16 +2,9 @@ name: Web Tests
on:
workflow_call:
inputs:
base_sha:
secrets:
CODECOV_TOKEN:
required: false
type: string
diff_range_mode:
required: false
type: string
head_sha:
required: false
type: string
permissions:
contents: read
@ -63,7 +56,7 @@ jobs:
needs: [test]
runs-on: ubuntu-latest
env:
VITEST_COVERAGE_SCOPE: app-components
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
@ -87,52 +80,16 @@ jobs:
merge-multiple: true
- name: Merge reports
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
run: vp test --merge-reports --coverage --silent=passed-only
- name: Report app/components baseline coverage
run: node ./scripts/report-components-coverage-baseline.mjs
- name: Report app/components test touch
env:
BASE_SHA: ${{ inputs.base_sha }}
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/report-components-test-touch.mjs
- name: Check app/components pure diff coverage
env:
BASE_SHA: ${{ inputs.base_sha }}
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/check-components-diff-coverage.mjs
- name: Check Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ -f "$COVERAGE_FILE" ] || [ -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 app/components Diff Coverage" >> "$GITHUB_STEP_SUMMARY"
echo "" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage artifacts not found. Ensure Vitest merge reports ran with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error
directory: web/coverage
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
web-build:
name: Web Build

View File

@ -1,9 +1,11 @@
import json
import logging
from typing import Any
from typing import Any, cast
import click
from pydantic import TypeAdapter
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from configs import dify_config
from core.helper import encrypter
@ -48,14 +50,15 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(ToolOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(ToolOAuthSystemClient).where(
ToolOAuthSystemClient.provider == provider_name,
ToolOAuthSystemClient.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@ -97,14 +100,15 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(TriggerOAuthSystemClient).where(
TriggerOAuthSystemClient.provider == provider_name,
TriggerOAuthSystemClient.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@ -139,14 +143,15 @@ def setup_datasource_oauth_client(provider, client_params):
return
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
deleted_count = (
db.session.query(DatasourceOauthParamConfig)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
deleted_count = cast(
CursorResult,
db.session.execute(
delete(DatasourceOauthParamConfig).where(
DatasourceOauthParamConfig.provider == provider_name,
DatasourceOauthParamConfig.plugin_id == plugin_id,
)
),
).rowcount
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@ -192,7 +197,9 @@ def transform_datasource_credentials(environment: str):
# deal notion credentials
deal_notion_count = 0
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
notion_credentials = db.session.scalars(
select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion")
).all()
if notion_credentials:
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
for notion_credential in notion_credentials:
@ -201,7 +208,7 @@ def transform_datasource_credentials(environment: str):
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:
@ -250,7 +257,9 @@ def transform_datasource_credentials(environment: str):
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
firecrawl_credentials = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl")
).all()
if firecrawl_credentials:
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for firecrawl_credential in firecrawl_credentials:
@ -259,7 +268,7 @@ def transform_datasource_credentials(environment: str):
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:
@ -312,7 +321,9 @@ def transform_datasource_credentials(environment: str):
db.session.commit()
# deal jina credentials
deal_jina_count = 0
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
jina_credentials = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader")
).all()
if jina_credentials:
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for jina_credential in jina_credentials:
@ -321,7 +332,7 @@ def transform_datasource_credentials(environment: str):
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
try:

View File

@ -1,7 +1,10 @@
import json
from typing import cast
import click
import sqlalchemy as sa
from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from configs import dify_config
from extensions.ext_database import db
@ -740,14 +743,17 @@ def migrate_oss(
else:
try:
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
updated = (
db.session.query(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
)
updated = cast(
CursorResult,
db.session.execute(
update(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.values(storage_type=dify_config.STORAGE_TYPE)
),
).rowcount
db.session.commit()
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
except Exception as e:

View File

@ -2,6 +2,7 @@ import logging
import click
import sqlalchemy as sa
from sqlalchemy import delete, select, update
from sqlalchemy.orm import sessionmaker
from configs import dify_config
@ -41,7 +42,7 @@ def reset_encrypt_key_pair():
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
tenants = session.query(Tenant).all()
tenants = session.scalars(select(Tenant)).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
@ -49,8 +50,8 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id)
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
click.echo(
click.style(
@ -93,7 +94,7 @@ def convert_to_agent_apps():
app_id = str(i.id)
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
app = db.session.query(App).where(App.id == app_id).first()
app = db.session.scalar(select(App).where(App.id == app_id))
if app is not None:
apps.append(app)
@ -108,8 +109,8 @@ def convert_to_agent_apps():
db.session.commit()
# update conversation mode to agent
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT}
db.session.execute(
update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT)
)
db.session.commit()
@ -177,7 +178,7 @@ where sites.id is null limit 1000"""
continue
try:
app = db.session.query(App).where(App.id == app_id).first()
app = db.session.scalar(select(App).where(App.id == app_id))
if not app:
logger.info("App %s not found", app_id)
continue

View File

@ -41,14 +41,13 @@ def migrate_annotation_vector_database():
# get apps info
per_page = 50
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
apps = (
session.query(App)
apps = session.scalars(
select(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
).all()
if not apps:
break
except SQLAlchemyError:
@ -63,8 +62,8 @@ def migrate_annotation_vector_database():
try:
click.echo(f"Creating app annotation index: {app.id}")
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1)
)
if not app_annotation_setting:
@ -72,10 +71,10 @@ def migrate_annotation_vector_database():
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
dataset_collection_binding = session.scalar(
select(DatasetCollectionBinding).where(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
)
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
@ -205,11 +204,11 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
dataset_collection_binding = db.session.execute(
select(DatasetCollectionBinding).where(
DatasetCollectionBinding.id == dataset.collection_binding_id
)
).scalar_one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
@ -334,7 +333,7 @@ def add_qdrant_index(field: str):
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
bindings = db.session.scalars(select(DatasetCollectionBinding)).all()
if not bindings:
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
@ -421,10 +420,10 @@ def old_metadata_migration():
if field.value == key:
break
else:
dataset_metadata = (
db.session.query(DatasetMetadata)
dataset_metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first()
.limit(1)
)
if not dataset_metadata:
dataset_metadata = DatasetMetadata(
@ -436,7 +435,7 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata)
db.session.flush()
dataset_metadata_binding = DatasetMetadataBinding(
dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding(
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
metadata_id=dataset_metadata.id,
@ -445,14 +444,14 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata_binding)
else:
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
dataset_metadata_binding = db.session.scalar(
select(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
)
.first()
.limit(1)
)
if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding(

View File

@ -103,13 +103,13 @@ class AppMCPServerController(Resource):
raise NotFound()
description = payload.description
if description is None:
pass
elif not description:
if description is None or not description:
server.description = app_model.description or ""
else:
server.description = description
server.name = app_model.name
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
try:

View File

@ -30,6 +30,7 @@ from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@ -335,7 +336,7 @@ class MessageFeedbackApi(Resource):
if not args.rating and feedback:
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.rating = FeedbackRating(args.rating)
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
@ -347,9 +348,9 @@ class MessageFeedbackApi(Resource):
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
rating=FeedbackRating(rating_value),
content=args.content,
from_source="admin",
from_source=FeedbackFromSource.ADMIN,
from_account_id=current_user.id,
)
db.session.add(feedback)

View File

@ -298,6 +298,7 @@ class DatasetDocumentListApi(Resource):
if sort == "hit_count":
sub_query = (
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.where(DocumentSegment.dataset_id == str(dataset_id))
.group_by(DocumentSegment.document_id)
.subquery()
)

View File

@ -24,6 +24,7 @@ from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
retrieval_model: RetrievalModel | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None

View File

@ -4,6 +4,7 @@ from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.enums import BannerStatus
from models.model import ExporleBanner
@ -16,7 +17,7 @@ class BannerApi(Resource):
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()

View File

@ -27,6 +27,7 @@ from fields.message_fields import MessageInfiniteScrollPagination, MessageListIt
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource):
app_model=app_model,
message_id=message_id,
user=current_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@ -5,6 +5,7 @@ from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
@ -36,23 +37,16 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_model = None
if is_anonymous:
user_model = (
session.query(EndUser)
user_model = session.scalar(
select(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
.limit(1)
)
else:
user_model = (
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
user_model = session.get(EndUser, user_id)
if not user_model:
user_model = EndUser(
@ -85,16 +79,7 @@ def get_user_tenant(view_func: Callable[P, R]):
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError("tenant not found")
tenant_model = db.session.get(Tenant, tenant_id)
if not tenant_model:
raise ValueError("tenant not found")

View File

@ -2,6 +2,7 @@ import json
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from controllers.common.schema import register_schema_models
from controllers.console.wraps import setup_required
@ -42,7 +43,7 @@ class EnterpriseWorkspace(Resource):
def post(self):
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
account = db.session.query(Account).filter_by(email=args.owner_email).first()
account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1))
if account is None:
return {"message": "owner account not found."}, 404

View File

@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
if signature_base64 != token:
return view(*args, **kwargs)
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
kwargs["user"] = db.session.get(EndUser, user_id)
return view(*args, **kwargs)

View File

@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource):
app_model=app_model,
message_id=message_id,
user=end_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@ -8,6 +8,7 @@ from datetime import datetime
from flask import Response, request
from flask_restx import Resource, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -147,11 +148,11 @@ class HumanInputFormApi(Resource):
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
"""Resolve App/Site for the form's app and validate tenant status."""
app_model = db.session.query(App).where(App.id == form.app_id).first()
app_model = db.session.get(App, form.app_id)
if app_model is None or app_model.tenant_id != form.tenant_id:
raise NotFoundError("Form not found")
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if site is None:
raise Forbidden()

View File

@ -25,6 +25,7 @@ from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import uuid_value
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource):
app_model=app_model,
message_id=message_id,
user=end_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@ -1,6 +1,7 @@
from typing import cast
from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -72,7 +73,7 @@ class AppSiteApi(WebApiResource):
def get(self, app_model, end_user):
"""Retrieve app site info."""
# get site
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise Forbidden()

View File

@ -76,7 +76,7 @@ from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole, MessageStatus
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.workflow import Workflow
@ -939,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
upload_file_id=file["related_id"],
created_by_role=CreatorUserRole.ACCOUNT
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}

View File

@ -40,7 +40,7 @@ from dify_graph.model_runtime.entities.message_entities import (
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING:
@ -419,7 +419,7 @@ class AppRunner:
message_id=message_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
url=f"/files/tools/{tool_file.id}",
upload_file_id=tool_file.id,
created_by_role=(

View File

@ -517,7 +517,7 @@ class WorkflowResponseConverter:
snapshot = self._pop_snapshot(event.node_execution_id)
start_at = snapshot.start_at if snapshot else event.start_at
finished_at = naive_utc_now()
finished_at = event.finished_at or naive_utc_now()
elapsed_time = (finished_at - start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)

View File

@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),

View File

@ -456,6 +456,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@ -471,6 +472,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
@ -487,6 +489,7 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,

View File

@ -335,6 +335,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
@ -390,6 +391,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
@ -414,6 +416,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
finished_at: datetime | None = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)

View File

@ -34,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.enums import MessageFileBelongsTo
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
@ -233,7 +234,7 @@ class MessageCycleManager:
task_id=self._application_generate_entity.task_id,
id=message_file.id,
type=message_file.type,
belongs_to=message_file.belongs_to or "user",
belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER,
url=url,
)

View File

@ -268,7 +268,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.SUCCEEDED,
finished_at=event.finished_at,
)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
@ -277,6 +282,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
event.node_run_result,
WorkflowNodeExecutionStatus.FAILED,
error=event.error,
finished_at=event.finished_at,
)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
@ -286,6 +292,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
event.node_run_result,
WorkflowNodeExecutionStatus.EXCEPTION,
error=event.error,
finished_at=event.finished_at,
)
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
@ -352,13 +359,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
*,
error: str | None = None,
update_outputs: bool = True,
finished_at: datetime | None = None,
) -> None:
finished_at = naive_utc_now()
actual_finished_at = finished_at or naive_utc_now()
snapshot = self._node_snapshots.get(domain_execution.id)
start_at = snapshot.created_at if snapshot else domain_execution.created_at
domain_execution.status = status
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
domain_execution.finished_at = actual_finished_at
domain_execution.elapsed_time = max((actual_finished_at - start_at).total_seconds(), 0.0)
if error:
domain_execution.error = error

View File

@ -15,6 +15,7 @@ from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile
@ -81,7 +82,7 @@ class DatasourceFileManager:
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=filepath,
name=present_filename,
size=len(file_binary),

View File

@ -1422,12 +1422,12 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = s.execute(stmt).scalars().first()
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
preferred_model_provider.preferred_provider_type = provider_type
else:
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value,
preferred_provider_type=provider_type,
)
s.add(preferred_model_provider)
s.commit()

View File

@ -195,7 +195,7 @@ class ProviderManager:
preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name)
if preferred_provider_type_record:
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
preferred_provider_type = preferred_provider_type_record.preferred_provider_type
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM
elif custom_configuration.provider or custom_configuration.models:

View File

@ -1,9 +1,10 @@
import re
from typing import Any
class CleanProcessor:
@classmethod
def clean(cls, text: str, process_rule: dict) -> str:
def clean(cls, text: str, process_rule: dict[str, Any] | None) -> str:
# default clean
# remove invalid symbol
text = re.sub(r"<\|", "<", text)

View File

@ -4,6 +4,7 @@ from typing import Any
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@ -15,6 +16,11 @@ from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class PreSegmentData(TypedDict):
segment: DocumentSegment
keywords: list[str]
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
@ -128,7 +134,7 @@ class Jieba(BaseKeyword):
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
storage.delete(file_key)
def _save_dataset_keyword_table(self, keyword_table):
def _save_dataset_keyword_table(self, keyword_table: dict[str, set[str]] | None):
keyword_table_dict = {
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
@ -144,7 +150,7 @@ class Jieba(BaseKeyword):
storage.delete(file_key)
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
def _get_dataset_keyword_table(self) -> dict | None:
def _get_dataset_keyword_table(self) -> dict[str, set[str]] | None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
@ -169,14 +175,16 @@ class Jieba(BaseKeyword):
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]):
def _add_text_to_keyword_table(
self, keyword_table: dict[str, set[str]], id: str, keywords: list[str]
) -> dict[str, set[str]]:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]):
def _delete_ids_from_keyword_table(self, keyword_table: dict[str, set[str]], ids: list[str]) -> dict[str, set[str]]:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
@ -193,7 +201,7 @@ class Jieba(BaseKeyword):
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
def _retrieve_ids_by_query(self, keyword_table: dict[str, set[str]], query: str, k: int = 4) -> list[str]:
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
@ -228,7 +236,7 @@ class Jieba(BaseKeyword):
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
def multi_create_segment_keywords(self, pre_segment_data_list: list[PreSegmentData]):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:

View File

@ -68,9 +68,12 @@ class SegmentRecord(TypedDict):
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod | str
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool
@ -100,7 +103,7 @@ class RetrievalService:
reranking_mode: str = "reranking_model",
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
attachment_ids: list[str] | None = None,
):
if not query and not attachment_ids:
return []
@ -247,8 +250,8 @@ class RetrievalService:
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
exceptions: list,
all_documents: list[Document],
exceptions: list[str],
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():
@ -276,9 +279,9 @@ class RetrievalService:
top_k: int,
score_threshold: float | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
all_documents: list[Document],
retrieval_method: RetrievalMethod,
exceptions: list,
exceptions: list[str],
document_ids_filter: list[str] | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
):
@ -370,9 +373,9 @@ class RetrievalService:
top_k: int,
score_threshold: float | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
all_documents: list[Document],
retrieval_method: str,
exceptions: list,
exceptions: list[str],
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():

View File

@ -1,12 +1,38 @@
import json
import time
from typing import Any, cast
from typing import Any, NotRequired, cast
import httpx
from typing_extensions import TypedDict
from extensions.ext_storage import storage
class FirecrawlDocumentData(TypedDict):
title: str | None
description: str | None
source_url: str | None
markdown: str | None
class CrawlStatusResponse(TypedDict):
status: str
total: int | None
current: int | None
data: list[FirecrawlDocumentData]
class MapResponse(TypedDict):
success: bool
links: list[str]
class SearchResponse(TypedDict):
success: bool
data: list[dict[str, Any]]
warning: NotRequired[str]
class FirecrawlApp:
def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
@ -14,7 +40,7 @@ class FirecrawlApp:
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
raise ValueError("No API key provided")
def scrape_url(self, url, params=None) -> dict[str, Any]:
def scrape_url(self, url, params=None) -> FirecrawlDocumentData:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape
headers = self._prepare_headers()
json_data = {
@ -32,9 +58,7 @@ class FirecrawlApp:
return self._extract_common_fields(data)
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "scrape URL")
return {} # Avoid additional exception after handling error
else:
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
def crawl_url(self, url, params=None) -> str:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
@ -51,7 +75,7 @@ class FirecrawlApp:
self._handle_error(response, "start crawl job")
return "" # unreachable
def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
def map(self, url: str, params: dict[str, Any] | None = None) -> MapResponse:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map
headers = self._prepare_headers()
json_data: dict[str, Any] = {"url": url, "integration": "dify"}
@ -60,14 +84,12 @@ class FirecrawlApp:
json_data.update(params)
response = self._post_request(self._build_url("v2/map"), json_data, headers)
if response.status_code == 200:
return cast(dict[str, Any], response.json())
return cast(MapResponse, response.json())
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "start map job")
return {}
else:
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
def check_crawl_status(self, job_id) -> dict[str, Any]:
def check_crawl_status(self, job_id) -> CrawlStatusResponse:
headers = self._prepare_headers()
response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
if response.status_code == 200:
@ -77,7 +99,7 @@ class FirecrawlApp:
if total == 0:
raise Exception("Failed to check crawl status. Error: No page found")
data = crawl_status_response.get("data", [])
url_data_list = []
url_data_list: list[FirecrawlDocumentData] = []
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = self._extract_common_fields(item)
@ -95,13 +117,15 @@ class FirecrawlApp:
return self._format_crawl_status_response(
crawl_status_response.get("status"), crawl_status_response, []
)
else:
self._handle_error(response, "check crawl status")
return {} # unreachable
self._handle_error(response, "check crawl status")
raise RuntimeError("unreachable: _handle_error always raises")
def _format_crawl_status_response(
self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]]
) -> dict[str, Any]:
self,
status: str,
crawl_status_response: dict[str, Any],
url_data_list: list[FirecrawlDocumentData],
) -> CrawlStatusResponse:
return {
"status": status,
"total": crawl_status_response.get("total"),
@ -109,7 +133,7 @@ class FirecrawlApp:
"data": url_data_list,
}
def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]:
def _extract_common_fields(self, item: dict[str, Any]) -> FirecrawlDocumentData:
return {
"title": item.get("metadata", {}).get("title"),
"description": item.get("metadata", {}).get("description"),
@ -117,7 +141,7 @@ class FirecrawlApp:
"markdown": item.get("markdown"),
}
def _prepare_headers(self) -> dict[str, Any]:
def _prepare_headers(self) -> dict[str, str]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _build_url(self, path: str) -> str:
@ -150,10 +174,10 @@ class FirecrawlApp:
error_message = response.text or "Unknown error occurred"
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
def search(self, query: str, params: dict[str, Any] | None = None) -> SearchResponse:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search
headers = self._prepare_headers()
json_data = {
json_data: dict[str, Any] = {
"query": query,
"limit": 5,
"lang": "en",
@ -170,12 +194,10 @@ class FirecrawlApp:
json_data.update(params)
response = self._post_request(self._build_url("v2/search"), json_data, headers)
if response.status_code == 200:
response_data = response.json()
response_data: SearchResponse = response.json()
if not response_data.get("success"):
raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}")
return cast(dict[str, Any], response_data)
return response_data
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "perform search")
return {} # Avoid additional exception after handling error
else:
raise Exception(f"Failed to perform search. Status code: {response.status_code}")
raise Exception(f"Failed to perform search. Status code: {response.status_code}")

View File

@ -15,6 +15,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
@ -150,7 +151,7 @@ class PdfExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self._tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=len(img_bytes),

View File

@ -1,10 +1,11 @@
import json
from collections.abc import Generator
from typing import Union
from typing import Any, Union
from urllib.parse import urljoin
import httpx
from httpx import Response
from typing_extensions import TypedDict
from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,
@ -13,6 +14,27 @@ from core.rag.extractor.watercrawl.exceptions import (
)
class SpiderOptions(TypedDict):
max_depth: int
page_limit: int
allowed_domains: list[str]
exclude_paths: list[str]
include_paths: list[str]
class PageOptions(TypedDict):
exclude_tags: list[str]
include_tags: list[str]
wait_time: int
include_html: bool
only_main_content: bool
include_links: bool
timeout: int
accept_cookies_selector: str
locale: str
actions: list[Any]
class BaseAPIClient:
def __init__(self, api_key, base_url):
self.api_key = api_key
@ -121,9 +143,9 @@ class WaterCrawlAPIClient(BaseAPIClient):
def create_crawl_request(
self,
url: Union[list, str] | None = None,
spider_options: dict | None = None,
page_options: dict | None = None,
plugin_options: dict | None = None,
spider_options: SpiderOptions | None = None,
page_options: PageOptions | None = None,
plugin_options: dict[str, Any] | None = None,
):
data = {
# 'urls': url if isinstance(url, list) else [url],
@ -176,8 +198,8 @@ class WaterCrawlAPIClient(BaseAPIClient):
def scrape_url(
self,
url: str,
page_options: dict | None = None,
plugin_options: dict | None = None,
page_options: PageOptions | None = None,
plugin_options: dict[str, Any] | None = None,
sync: bool = True,
prefetched: bool = True,
):

View File

@ -2,16 +2,39 @@ from collections.abc import Generator
from datetime import datetime
from typing import Any
from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient
from typing_extensions import TypedDict
from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient
class WatercrawlDocumentData(TypedDict):
title: str | None
description: str | None
source_url: str | None
markdown: str | None
class CrawlJobResponse(TypedDict):
status: str
job_id: str | None
class WatercrawlCrawlStatusResponse(TypedDict):
status: str
job_id: str | None
total: int
current: int
data: list[WatercrawlDocumentData]
time_consuming: float
class WaterCrawlProvider:
def __init__(self, api_key, base_url: str | None = None):
self.client = WaterCrawlAPIClient(api_key, base_url)
def crawl_url(self, url, options: dict | Any | None = None):
def crawl_url(self, url: str, options: dict[str, Any] | None = None) -> CrawlJobResponse:
options = options or {}
spider_options = {
spider_options: SpiderOptions = {
"max_depth": 1,
"page_limit": 1,
"allowed_domains": [],
@ -25,7 +48,7 @@ class WaterCrawlProvider:
spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []
wait_time = options.get("wait_time", 1000)
page_options = {
page_options: PageOptions = {
"exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [],
"include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [],
"wait_time": max(1000, wait_time), # minimum wait time is 1 second
@ -41,9 +64,9 @@ class WaterCrawlProvider:
return {"status": "active", "job_id": result.get("uuid")}
def get_crawl_status(self, crawl_request_id):
def get_crawl_status(self, crawl_request_id: str) -> WatercrawlCrawlStatusResponse:
response = self.client.get_crawl_request(crawl_request_id)
data = []
data: list[WatercrawlDocumentData] = []
if response["status"] in ["new", "running"]:
status = "active"
else:
@ -67,7 +90,7 @@ class WaterCrawlProvider:
"time_consuming": time_consuming,
}
def get_crawl_url_data(self, job_id, url) -> dict | None:
def get_crawl_url_data(self, job_id: str, url: str) -> WatercrawlDocumentData | None:
if not job_id:
return self.scrape_url(url)
@ -82,11 +105,11 @@ class WaterCrawlProvider:
return None
def scrape_url(self, url: str):
def scrape_url(self, url: str) -> WatercrawlDocumentData:
response = self.client.scrape_url(url=url, sync=True, prefetched=True)
return self._structure_data(response)
def _structure_data(self, result_object: dict):
def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData:
if isinstance(result_object.get("result", {}), str):
raise ValueError("Invalid result object. Expected a dictionary.")
@ -98,7 +121,9 @@ class WaterCrawlProvider:
"markdown": result_object.get("result", {}).get("markdown"),
}
def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]:
def _get_results(
self, crawl_request_id: str, query_params: dict | None = None
) -> Generator[WatercrawlDocumentData, None, None]:
page = 0
page_size = 100

View File

@ -21,6 +21,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
@ -112,7 +113,7 @@ class WordExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=0,
@ -140,7 +141,7 @@ class WordExtractor(BaseExtractor):
# save file to db
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=file_key,
size=0,
@ -365,7 +366,7 @@ class WordExtractor(BaseExtractor):
paragraph_content = []
# State for legacy HYPERLINK fields
hyperlink_field_url = None
hyperlink_field_text_parts: list = []
hyperlink_field_text_parts: list[str] = []
is_collecting_field_text = False
# Iterate through paragraph elements in document order
for child in paragraph._element:

View File

@ -33,7 +33,7 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
@ -87,7 +87,7 @@ from models.enums import CreatorUserRole, DatasetQuerySource
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
default_retrieval_model: dict[str, Any] = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -591,7 +591,7 @@ class DatasetRetrieval:
user_id: str,
user_from: str,
query: str,
available_datasets: list,
available_datasets: list[Dataset],
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
@ -633,15 +633,15 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if dataset:
selected_dataset = db.session.scalar(dataset_stmt)
if selected_dataset:
results = []
if dataset.provider == "external":
if selected_dataset.provider == "external":
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
tenant_id=selected_dataset.tenant_id,
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
external_retrieval_parameters=selected_dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
@ -654,24 +654,28 @@ class DatasetRetrieval:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
document.metadata["dataset_name"] = selected_dataset.name
results.append(document)
else:
if metadata_condition and not metadata_filter_document_ids:
return []
document_ids_filter = None
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
document_ids = metadata_filter_document_ids.get(selected_dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
return []
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
retrieval_model_config: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, selected_dataset.retrieval_model)
if selected_dataset.retrieval_model
else default_retrieval_model
)
# get top k
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
if selected_dataset.indexing_technique == "economy":
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
@ -690,7 +694,7 @@ class DatasetRetrieval:
with measure_time() as timer:
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
dataset_id=selected_dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
@ -722,7 +726,7 @@ class DatasetRetrieval:
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
available_datasets: list[Dataset],
query: str | None,
top_k: int,
score_threshold: float,
@ -1024,7 +1028,7 @@ class DatasetRetrieval:
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
all_documents: list[Document],
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
@ -1058,7 +1062,11 @@ class DatasetRetrieval:
all_documents.append(document)
else:
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_model: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
else default_retrieval_model
)
if dataset.indexing_technique == "economy":
# use keyword table query
@ -1132,7 +1140,7 @@ class DatasetRetrieval:
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -1141,7 +1149,11 @@ class DatasetRetrieval:
}
for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
retrieval_model_config: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
else default_retrieval_model
)
# get top k
top_k = retrieval_model_config["top_k"]
@ -1286,7 +1298,7 @@ class DatasetRetrieval:
def get_metadata_filter_condition(
self,
dataset_ids: list,
dataset_ids: list[str],
query: str,
tenant_id: str,
user_id: str,
@ -1388,7 +1400,7 @@ class DatasetRetrieval:
return output
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
self, dataset_ids: list[str], query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> list[dict[str, Any]] | None:
# get all metadata field
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
@ -1586,7 +1598,7 @@ class DatasetRetrieval:
)
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list[str], query: str
):
model_mode = ModelMode(mode)
input_text = query
@ -1678,7 +1690,7 @@ class DatasetRetrieval:
def _multiple_retrieve_thread(
self,
flask_app: Flask,
available_datasets: list,
available_datasets: list[Dataset],
metadata_condition: MetadataCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],

View File

@ -34,7 +34,7 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
from dify_graph.file import FileType
from dify_graph.file.models import FileTransferMethod
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import Message, MessageFile
logger = logging.getLogger(__name__)
@ -352,7 +352,7 @@ class ToolEngine:
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
url=message.url,
upload_file_id=tool_file_id,
created_by_role=(

View File

@ -3,7 +3,6 @@ from typing import Final
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
{

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, cast
from typing import Any
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
@ -47,7 +47,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
# Get trigger data passed when workflow was triggered
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self.node_data.provider_id,
"event_name": self.node_data.event_name,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,

View File

@ -245,6 +245,9 @@ _END_STATE = frozenset(
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
Values in this enum are persisted as execution metadata and must stay in sync
with every node that writes `NodeRunResult.metadata`.
"""
TOTAL_TOKENS = "total_tokens"
@ -266,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
TRIGGER_INFO = "trigger_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node

View File

@ -159,6 +159,7 @@ class ErrorHandler:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
@ -198,6 +199,7 @@ class ErrorHandler:
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,

View File

@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from dify_graph.context import IExecutionContext
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from libs.datetime_utils import naive_utc_now
from .ready_queue import ReadyQueue
@ -65,6 +68,7 @@ class Worker(threading.Thread):
self._stop_event = threading.Event()
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
self._current_node_started_at: datetime | None = None
def stop(self) -> None:
"""Signal the worker to stop processing."""
@ -104,18 +108,15 @@ class Worker(threading.Thread):
self._last_task_time = time.time()
node = self._graph.nodes[node_id]
try:
self._current_node_started_at = None
self._execute_node(node)
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),
self._event_queue.put(
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
)
self._event_queue.put(error_event)
finally:
self._current_node_started_at = None
def _execute_node(self, node: Node) -> None:
"""
@ -136,6 +137,8 @@ class Worker(threading.Thread):
try:
node_events = node.run()
for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
@ -149,6 +152,8 @@ class Worker(threading.Thread):
try:
node_events = node.run()
for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
@ -177,3 +182,24 @@ class Worker(threading.Thread):
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _build_fallback_failure_event(
self, node: Node, error: Exception, *, started_at: datetime | None = None
) -> NodeRunFailedEvent:
"""Build a failed event when worker-level execution aborts before a node emits its own result event."""
failure_time = naive_utc_now()
error_message = str(error)
return NodeRunFailedEvent(
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=error_message,
start_at=started_at or failure_time,
finished_at=failure_time,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
error_type=type(error).__name__,
),
)

View File

@ -36,16 +36,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
class NodeRunSucceededEvent(GraphNodeEventBase):
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunFailedEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunExceptionEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunRetryEvent(NodeRunStartedEvent):

View File

@ -406,11 +406,13 @@ class Node(Generic[NodeDataT]):
error=str(e),
error_type="WorkflowNodeError",
)
finished_at = naive_utc_now()
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
error=str(e),
)
@ -568,6 +570,7 @@ class Node(Generic[NodeDataT]):
return self._node_data
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
finished_at = naive_utc_now()
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
@ -575,6 +578,7 @@ class Node(Generic[NodeDataT]):
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
error=result.error,
)
@ -584,6 +588,7 @@ class Node(Generic[NodeDataT]):
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=result,
)
case _:
@ -606,6 +611,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
finished_at = naive_utc_now()
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
@ -613,6 +619,7 @@ class Node(Generic[NodeDataT]):
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=event.node_run_result,
)
case WorkflowNodeExecutionStatus.FAILED:
@ -621,6 +628,7 @@ class Node(Generic[NodeDataT]):
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
finished_at=finished_at,
node_run_result=event.node_run_result,
error=event.node_run_result.error,
)

View File

@ -236,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
future_to_index: dict[
Future[
tuple[
datetime,
float,
list[GraphNodeEventBase],
object | None,
dict[str, Variable],
@ -261,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
try:
result = future.result()
(
iter_start_at,
iteration_duration,
events,
output_value,
conversation_snapshot,
@ -274,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# Yield all events from this iteration
yield from events
# Update tokens and timing
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
# The worker computes duration before we replay buffered events here,
# so slow downstream consumers don't inflate per-iteration timing.
iter_run_map[str(index)] = iteration_duration
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
@ -305,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
index: int,
item: object,
execution_context: "IExecutionContext",
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -327,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
conversation_snapshot = self._extract_conversation_variable_snapshot(
variable_pool=graph_engine.graph_runtime_state.variable_pool
)
iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
return (
iter_start_at,
iteration_duration,
events,
output_value,
conversation_snapshot,

View File

@ -256,9 +256,13 @@ def fetch_prompt_messages(
):
continue
prompt_message_content.append(content_item)
if prompt_message_content:
if not prompt_message_content:
continue
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
filtered_prompt_messages.append(prompt_message)
filtered_prompt_messages.append(prompt_message)
elif not prompt_message.is_empty():
filtered_prompt_messages.append(prompt_message)

View File

@ -3,6 +3,7 @@ import logging
import time
import click
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@ -24,13 +25,11 @@ def handle(sender, **kwargs):
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document)
.where(
document = db.session.scalar(
select(Document).where(
Document.id == document_id,
Document.dataset_id == dataset_id,
)
.first()
)
if not document:

View File

@ -1,6 +1,6 @@
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy import delete, select
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
db.session.execute(
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
)
if added_dataset_ids:
for dataset_id in added_dataset_ids:

View File

@ -1,6 +1,6 @@
from typing import cast
from sqlalchemy import select
from sqlalchemy import delete, select
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from dify_graph.nodes import BuiltinNodeTypes
@ -31,9 +31,9 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
db.session.execute(
delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id)
)
if added_dataset_ids:
for dataset_id in added_dataset_ids:

View File

@ -3,6 +3,7 @@ import json
import flask_login
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@ -34,16 +35,15 @@ def load_user_from_request(request_from_flask_login):
if admin_api_key and admin_api_key == auth_token:
workspace_id = request.headers.get("X-WORKSPACE-ID")
if workspace_id:
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == workspace_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role == "owner")
.one_or_none()
)
).one_or_none()
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter_by(id=ta.account_id).first()
account = db.session.scalar(select(Account).where(Account.id == ta.account_id))
if account:
account.current_tenant = tenant
return account
@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login):
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound("End user not found.")
return end_user
@ -80,7 +80,7 @@ def load_user_from_request(request_from_flask_login):
decoded = PassportService().verify(auth_token)
end_user_id = decoded.get("end_user_id")
if end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound("End user not found.")
return end_user
@ -90,11 +90,11 @@ def load_user_from_request(request_from_flask_login):
server_code = request.view_args.get("server_code") if request.view_args else None
if not server_code:
raise Unauthorized("Invalid Authorization token.")
app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
app_mcp_server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
if not app_mcp_server:
raise NotFound("App MCP server not found.")
end_user = (
db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first()
end_user = db.session.scalar(
select(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").limit(1)
)
if not end_user:
raise NotFound("End user not found.")

View File

@ -32,7 +32,7 @@ class OpenDALStorage(BaseStorage):
kwargs = kwargs or _get_opendal_kwargs(scheme=scheme)
if scheme == "fs":
root = kwargs.get("root", "storage")
root = kwargs.setdefault("root", "storage")
Path(root).mkdir(parents=True, exist_ok=True)
retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True)

View File

@ -424,13 +424,11 @@ def _build_from_datasource_file(
datasource_file_id = mapping.get("datasource_file_id")
if not datasource_file_id:
raise ValueError(f"DatasourceFile {datasource_file_id} not found")
datasource_file = (
db.session.query(UploadFile)
.where(
datasource_file = db.session.scalar(
select(UploadFile).where(
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
.first()
)
if datasource_file is None:

View File

@ -158,6 +158,13 @@ class FeedbackFromSource(StrEnum):
ADMIN = "admin"
class FeedbackRating(StrEnum):
"""MessageFeedback rating"""
LIKE = "like"
DISLIKE = "dislike"
class InvokeFrom(StrEnum):
"""How a conversation/message was invoked"""

View File

@ -23,13 +23,25 @@ from core.tools.signature import sign_tool_file
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from dify_graph.file import helpers as file_helpers
from extensions.storage.storage_type import StorageType
from libs.helper import generate_string # type: ignore[import-not-found]
from libs.uuid_utils import uuidv7
from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db
from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus
from .enums import (
AppMCPServerStatus,
AppStatus,
BannerStatus,
ConversationStatus,
CreatorUserRole,
FeedbackFromSource,
FeedbackRating,
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
)
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
@ -925,8 +937,11 @@ class ExporleBanner(TypeBase):
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
status: Mapped[str] = mapped_column(
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
status: Mapped[BannerStatus] = mapped_column(
EnumText(BannerStatus, length=255),
nullable=False,
server_default=sa.text("'enabled'::character varying"),
default=BannerStatus.ENABLED,
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@ -1153,7 +1168,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
or 0
@ -1164,7 +1179,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
or 0
@ -1179,7 +1194,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
or 0
@ -1190,7 +1205,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
or 0
@ -1713,8 +1728,8 @@ class MessageFeedback(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
rating: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
@ -1767,7 +1782,9 @@ class MessageFile(TypeBase):
)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
)
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
@ -2097,7 +2114,7 @@ class UploadFile(Base):
# The `server_default` serves as a fallback mechanism.
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
key: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@ -2141,7 +2158,7 @@ class UploadFile(Base):
self,
*,
tenant_id: str,
storage_type: str,
storage_type: StorageType,
key: str,
name: str,
size: int,
@ -2206,7 +2223,7 @@ class MessageChain(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False)
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
output: Mapped[str | None] = mapped_column(LongText, nullable=True)
created_at: Mapped[datetime] = mapped_column(

View File

@ -210,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase):
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -22,14 +22,14 @@ from sqlalchemy import (
from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
from dify_graph.file.constants import maybe_file_object
from dify_graph.file.models import File
from dify_graph.variables import utils as variable_utils
@ -936,8 +936,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
datasource_info = execution_metadata["datasource_info"]
extras["icon"] = datasource_info.get("icon")
elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
elif (
self.node_type == TRIGGER_PLUGIN_NODE_TYPE
and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
):
trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
provider_id = trigger_info.get("provider_id")
if provider_id:
extras["icon"] = TriggerManager.get_trigger_plugin_icon(

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.13.1"
version = "1.13.2"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -1,6 +1,6 @@
[pytest]
pythonpath = .
addopts = --cov=./api --cov-report=json --import-mode=importlib
addopts = --cov=./api --cov-report=json --import-mode=importlib --cov-branch --cov-report=xml
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com

View File

@ -3,6 +3,7 @@ import math
import time
import click
from sqlalchemy import select
import app
from core.helper.marketplace import fetch_global_plugin_manifest
@ -28,17 +29,15 @@ def check_upgradable_plugin_task():
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green"))
strategies = (
db.session.query(TenantPluginAutoUpgradeStrategy)
.where(
strategies = db.session.scalars(
select(TenantPluginAutoUpgradeStrategy).where(
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
TenantPluginAutoUpgradeStrategy.strategy_setting
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
)
.all()
)
).all()
total_strategies = len(strategies)
click.echo(click.style(f"Total strategies: {total_strategies}", fg="green"))

View File

@ -2,7 +2,7 @@ import datetime
import time
import click
from sqlalchemy import text
from sqlalchemy import select, text
from sqlalchemy.exc import SQLAlchemyError
import app
@ -19,14 +19,12 @@ def clean_embedding_cache_task():
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
while True:
try:
embedding_ids = (
db.session.query(Embedding.id)
embedding_ids = db.session.scalars(
select(Embedding.id)
.where(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc())
.limit(100)
.all()
)
embedding_ids = [embedding_id[0] for embedding_id in embedding_ids]
).all()
except SQLAlchemyError:
raise
if embedding_ids:

View File

@ -3,7 +3,7 @@ import time
from typing import TypedDict
import click
from sqlalchemy import func, select
from sqlalchemy import func, select, update
from sqlalchemy.exc import SQLAlchemyError
import app
@ -51,7 +51,7 @@ def clean_unused_datasets_task():
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
select(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
@ -64,7 +64,7 @@ def clean_unused_datasets_task():
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
select(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
@ -142,8 +142,8 @@ def clean_unused_datasets_task():
index_processor.clean(dataset, None)
# Update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
{Document.enabled: False}
db.session.execute(
update(Document).where(Document.dataset_id == dataset.id).values(enabled=False)
)
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))

View File

@ -1,6 +1,7 @@
import time
import click
from sqlalchemy import func, select
import app
from configs import dify_config
@ -20,7 +21,7 @@ def create_tidb_serverless_task():
try:
# check the number of idle tidb serverless
idle_tidb_serverless_number = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count()
db.session.scalar(select(func.count(TidbAuthBinding.id)).where(TidbAuthBinding.active == False)) or 0
)
if idle_tidb_serverless_number >= tidb_serverless_number:
break

View File

@ -49,16 +49,18 @@ def mail_clean_document_notify_task():
if plan != CloudPlan.SANDBOX:
knowledge_details = []
# check tenant
tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
# check current owner
current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
)
if not current_owner_join:
continue
account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first()
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
if not account:
continue
@ -71,7 +73,7 @@ def mail_clean_document_notify_task():
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")

View File

@ -7,6 +7,7 @@ from flask import Response
from sqlalchemy import or_
from extensions.ext_database import db
from models.enums import FeedbackRating
from models.model import Account, App, Conversation, Message, MessageFeedback
@ -100,7 +101,7 @@ class FeedbackService:
"ai_response": message.answer[:500] + "..."
if len(message.answer) > 500
else message.answer, # Truncate long responses
"feedback_rating": "👍" if feedback.rating == "like" else "👎",
"feedback_rating": "👍" if feedback.rating == FeedbackRating.LIKE else "👎",
"feedback_rating_raw": feedback.rating,
"feedback_comment": feedback.content or "",
"feedback_source": feedback.from_source,

View File

@ -23,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from dify_graph.file import helpers as file_helpers
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from models import Account
@ -93,7 +94,7 @@ class FileService:
# save file to db
upload_file = UploadFile(
tenant_id=current_tenant_id or "",
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=filename,
size=file_size,
@ -152,7 +153,7 @@ class FileService:
# save file to db
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
storage_type=StorageType(dify_config.STORAGE_TYPE),
key=file_key,
name=text_name,
size=len(text),

View File

@ -16,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
from repositories.sqlalchemy_execution_extra_content_repository import (
@ -172,7 +173,7 @@ class MessageService:
app_model: App,
message_id: str,
user: Union[Account, EndUser] | None,
rating: str | None,
rating: FeedbackRating | None,
content: str | None,
):
if not user:
@ -197,7 +198,7 @@ class MessageService:
message_id=message.id,
rating=rating,
content=content,
from_source=("user" if isinstance(user, EndUser) else "admin"),
from_source=(FeedbackFromSource.USER if isinstance(user, EndUser) else FeedbackFromSource.ADMIN),
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
from_account_id=(user.id if isinstance(user, Account) else None),
)

View File

@ -9,7 +9,7 @@ import httpx
from flask_login import current_user
from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -216,8 +216,10 @@ class WebsiteService:
"max_depth": request.options.max_depth,
"use_sitemap": request.options.use_sitemap,
}
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
url=request.url, options=options
return dict(
WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
url=request.url, options=options
)
)
@classmethod
@ -270,13 +272,13 @@ class WebsiteService:
@classmethod
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
result = firecrawl_app.check_crawl_status(job_id)
crawl_status_data = {
"status": result.get("status", "active"),
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
crawl_status_data: dict[str, Any] = {
"status": result["status"],
"job_id": job_id,
"total": result.get("total", 0),
"current": result.get("current", 0),
"data": result.get("data", []),
"total": result["total"] or 0,
"current": result["current"] or 0,
"data": result["data"],
}
if crawl_status_data["status"] == "completed":
website_crawl_time_cache_key = f"website_crawl_{job_id}"
@ -289,8 +291,8 @@ class WebsiteService:
return crawl_status_data
@classmethod
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id))
@classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
@ -343,7 +345,7 @@ class WebsiteService:
@classmethod
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
crawl_data: list[dict[str, Any]] | None = None
crawl_data: list[FirecrawlDocumentData] | None = None
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
stored_data = storage.load_once(file_key)
@ -352,19 +354,22 @@ class WebsiteService:
else:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
result = firecrawl_app.check_crawl_status(job_id)
if result.get("status") != "completed":
if result["status"] != "completed":
raise ValueError("Crawl job is not completed")
crawl_data = result.get("data")
crawl_data = result["data"]
if crawl_data:
for item in crawl_data:
if item.get("source_url") == url:
if item["source_url"] == url:
return dict(item)
return None
@classmethod
def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
def _get_watercrawl_url_data(
cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
) -> dict[str, Any] | None:
result = WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
return dict(result) if result is not None else None
@classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
@ -416,8 +421,8 @@ class WebsiteService:
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
params = {"onlyMainContent": request.only_main_content}
return firecrawl_app.scrape_url(url=request.url, params=params)
return dict(firecrawl_app.scrape_url(url=request.url, params=params))
@classmethod
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
return dict(WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url))

View File

@ -14,6 +14,7 @@ from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import AppMode, MessageFeedback
from services.feedback_service import FeedbackService
@ -77,8 +78,8 @@ class TestFeedbackExportApi:
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content=None,
from_end_user_id=str(uuid.uuid4()),
from_account_id=None,
@ -90,8 +91,8 @@ class TestFeedbackExportApi:
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating="dislike",
from_source="admin",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.ADMIN,
content="The response was not helpful",
from_end_user_id=None,
from_account_id=str(uuid.uuid4()),
@ -277,8 +278,8 @@ class TestFeedbackExportApi:
# Verify service was called with correct parameters
mock_export_feedbacks.assert_called_once_with(
app_id=mock_app_model.id,
from_source="user",
rating="dislike",
from_source=FeedbackFromSource.USER,
rating=FeedbackRating.DISLIKE,
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",

View File

@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=storage_key,
name="test_file.txt",
size=1024,
@ -288,7 +289,7 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="other_tenant_key",
name="other_file.txt",
size=1024,

View File

@ -13,6 +13,7 @@ from dify_graph.variables.types import SegmentType
from dify_graph.variables.variables import StringVariable
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from factories.variable_factory import build_segment
from libs import datetime_utils
from models.enums import CreatorUserRole
@ -347,7 +348,7 @@ class TestDraftVariableLoader(unittest.TestCase):
# Create an upload file record
upload_file = UploadFile(
tenant_id=self._test_tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_offload_{uuid.uuid4()}.json",
name="test_offload.json",
size=len(content_bytes),
@ -450,7 +451,7 @@ class TestDraftVariableLoader(unittest.TestCase):
# Create upload file record
upload_file = UploadFile(
tenant_id=self._test_tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_integration_{uuid.uuid4()}.txt",
name="test_integration.txt",
size=len(content_bytes),

View File

@ -6,6 +6,7 @@ from sqlalchemy import delete
from core.db.session_factory import session_factory
from dify_graph.variables.segments import StringSegment
from extensions.storage.storage_type import StorageType
from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
@ -197,7 +198,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="test/file1.json",
name="file1.json",
size=1024,
@ -210,7 +211,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="test/file2.json",
name="file2.json",
size=2048,
@ -430,7 +431,7 @@ class TestDeleteDraftVariablesSessionCommit:
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="test/file1.json",
name="file1.json",
size=1024,
@ -443,7 +444,7 @@ class TestDeleteDraftVariablesSessionCommit:
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="test/file2.json",
name="file2.json",
size=2048,

View File

@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase):
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=storage_key,
name="test_file.txt",
size=1024,
@ -289,7 +290,7 @@ class TestStorageKeyLoader(unittest.TestCase):
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="other_tenant_key",
name="other_file.txt",
size=1024,

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
@ -198,7 +199,7 @@ class DocumentStatusTestDataFactory:
"""
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"uploads/{uuid4()}",
name=name,
size=128,

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
from models.enums import MessageFileBelongsTo
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
@ -852,7 +853,7 @@ class TestAgentService:
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="http://example.com/file1.jpg",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
@ -861,7 +862,7 @@ class TestAgentService:
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="http://example.com/file2.png",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)

View File

@ -7,6 +7,7 @@ from uuid import uuid4
import pytest
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom
@ -83,7 +84,7 @@ def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, n
"""Persist an upload file row referenced by document.data_source_info."""
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"uploads/{uuid4()}",
name=name,
size=128,

View File

@ -8,6 +8,7 @@ from unittest import mock
import pytest
from extensions.ext_database import db
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
@ -47,8 +48,8 @@ class TestFeedbackService:
app_id=app_id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content="Great answer!",
from_end_user_id="user-123",
from_account_id=None,
@ -61,8 +62,8 @@ class TestFeedbackService:
app_id=app_id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="dislike",
from_source="admin",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.ADMIN,
content="Could be more detailed",
from_end_user_id=None,
from_account_id="admin-456",
@ -179,8 +180,8 @@ class TestFeedbackService:
# Test with filters
result = FeedbackService.export_feedbacks(
app_id=sample_data["app"].id,
from_source="admin",
rating="dislike",
from_source=FeedbackFromSource.ADMIN,
rating=FeedbackRating.DISLIKE,
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",
@ -293,8 +294,8 @@ class TestFeedbackService:
app_id=sample_data["app"].id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="dislike",
from_source="user",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.USER,
content="回答不够详细,需要更多信息",
from_end_user_id="user-123",
from_account_id=None,

View File

@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
from extensions.storage.storage_type import StorageType
from models import Account, Tenant
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
@ -140,7 +141,7 @@ class TestFileService:
upload_file = UploadFile(
tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()),
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"upload_files/test/{fake.uuid4()}.txt",
name="test_file.txt",
size=1024,

View File

@ -7,6 +7,7 @@ import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import (
App,
AppAnnotationHitHistory,
@ -172,8 +173,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content="first",
from_end_user_id=conversation.from_end_user_id,
)
@ -181,8 +182,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.USER,
content="second",
from_end_user_id=conversation.from_end_user_id,
)
@ -190,8 +191,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.ADMIN,
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)

View File

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import FeedbackRating
from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (
@ -405,7 +406,7 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create feedback
rating = "like"
rating = FeedbackRating.LIKE
content = fake.text(max_nb_chars=100)
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=rating, content=content
@ -435,7 +436,11 @@ class TestMessageService:
# Test creating feedback with no user
with pytest.raises(ValueError, match="user cannot be None"):
MessageService.create_feedback(
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
app_model=app,
message_id=message.id,
user=None,
rating=FeedbackRating.LIKE,
content=fake.text(max_nb_chars=100),
)
def test_create_feedback_update_existing(
@ -452,14 +457,14 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial feedback
initial_rating = "like"
initial_rating = FeedbackRating.LIKE
initial_content = fake.text(max_nb_chars=100)
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
)
# Update feedback
updated_rating = "dislike"
updated_rating = FeedbackRating.DISLIKE
updated_content = fake.text(max_nb_chars=100)
updated_feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
@ -487,7 +492,11 @@ class TestMessageService:
# Create initial feedback
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
app_model=app,
message_id=message.id,
user=account,
rating=FeedbackRating.LIKE,
content=fake.text(max_nb_chars=100),
)
# Delete feedback by setting rating to None
@ -538,7 +547,7 @@ class TestMessageService:
app_model=app,
message_id=message.id,
user=account,
rating="like" if i % 2 == 0 else "dislike",
rating=FeedbackRating.LIKE if i % 2 == 0 else FeedbackRating.DISLIKE,
content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
)
feedbacks.append(feedback)
@ -568,7 +577,11 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
app_model=app,
message_id=message.id,
user=account,
rating=FeedbackRating.LIKE,
content=f"Feedback {i}",
)
# Get feedbacks with pagination

View File

@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
from models.model import (
App,
AppAnnotationHitHistory,
@ -166,7 +166,7 @@ class TestMessagesCleanServiceIntegration:
name="Test conversation",
inputs={},
status="normal",
from_source="api",
from_source=FeedbackFromSource.USER,
from_end_user_id=str(uuid.uuid4()),
)
db_session_with_containers.add(conversation)
@ -196,7 +196,7 @@ class TestMessagesCleanServiceIntegration:
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
from_source="api",
from_source=FeedbackFromSource.USER,
from_account_id=conversation.from_end_user_id,
created_at=created_at,
)
@ -216,8 +216,8 @@ class TestMessagesCleanServiceIntegration:
app_id=message.app_id,
conversation_id=message.conversation_id,
message_id=message.id,
rating="like",
from_source="api",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
from_end_user_id=str(uuid.uuid4()),
)
db_session_with_containers.add(feedback)
@ -236,7 +236,7 @@ class TestMessagesCleanServiceIntegration:
# MessageChain
chain = MessageChain(
message_id=message.id,
type="system",
type=MessageChainType.SYSTEM,
input=json.dumps({"test": "input"}),
output=json.dumps({"test": "output"}),
)
@ -249,7 +249,7 @@ class TestMessagesCleanServiceIntegration:
type="image",
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)

View File

@ -48,41 +48,42 @@ class TestToolTransformService:
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark='{"background": "#252525", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
credentials={"auth_type": "api_key_header", "api_key": "test_key"},
provider_type="api",
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
tools_str="[]",
)
elif provider_type == "builtin":
provider = BuiltinToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon="🔧",
icon_dark="🔧",
tenant_id="test_tenant_id",
user_id="test_user_id",
provider="test_provider",
credential_type="api_key",
credentials={"api_key": "test_key"},
encrypted_credentials='{"api_key": "test_key"}',
)
elif provider_type == "workflow":
provider = WorkflowToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark='{"background": "#252525", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
workflow_id="test_workflow_id",
app_id="test_workflow_id",
label="Test Workflow",
version="1.0.0",
parameter_configuration="[]",
)
elif provider_type == "mcp":
provider = MCPToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
icon='{"background": "#FF6B6B", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
server_url="https://mcp.example.com",
server_url_hash="test_server_url_hash",
server_identifier="test_server",
tools='[{"name": "test_tool", "description": "Test tool"}]',
authed=True,

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -209,7 +210,7 @@ class TestBatchCleanDocumentTask:
upload_file = UploadFile(
tenant_id=account.current_tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_files/{fake.file_name()}",
name=fake.file_name(),
size=1024,

View File

@ -19,6 +19,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -203,7 +204,7 @@ class TestBatchCreateSegmentToIndexTask:
upload_file = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_files/{fake.file_name()}",
name=fake.file_name(),
size=1024,

View File

@ -18,6 +18,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -254,7 +255,7 @@ class TestCleanDatasetTask:
upload_file = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_files/{fake.file_name()}",
name=fake.file_name(),
size=1024,
@ -925,7 +926,7 @@ class TestCleanDatasetTask:
special_filename = f"test_file_{special_content}.txt"
upload_file = UploadFile(
tenant_id=tenant.id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test_files/{special_filename}",
name=special_filename,
size=1024,

View File

@ -6,6 +6,7 @@ import pytest
from core.db.session_factory import session_factory
from dify_graph.variables.segments import StringSegment
from dify_graph.variables.types import SegmentType
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models import Tenant
from models.enums import CreatorUserRole
@ -78,7 +79,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id:
for i in range(count):
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key=f"test/file-{uuid.uuid4()}-{i}.json",
name=f"file-{i}.json",
size=1024 + i,

View File

@ -0,0 +1,56 @@
from pathlib import Path
from extensions.storage.opendal_storage import OpenDALStorage
class TestOpenDALFsDefaultRoot:
"""Test that OpenDALStorage with scheme='fs' works correctly when no root is provided."""
def test_fs_without_root_uses_default(self, tmp_path, monkeypatch):
"""When no root is specified, the default 'storage' should be used and passed to the Operator."""
# Change to tmp_path so the default "storage" dir is created there
monkeypatch.chdir(tmp_path)
# Ensure no OPENDAL_FS_ROOT env var is set
monkeypatch.delenv("OPENDAL_FS_ROOT", raising=False)
storage = OpenDALStorage(scheme="fs")
# The default directory should have been created
assert (tmp_path / "storage").is_dir()
# The storage should be functional
storage.save("test_default_root.txt", b"hello")
assert storage.exists("test_default_root.txt")
assert storage.load_once("test_default_root.txt") == b"hello"
# Cleanup
storage.delete("test_default_root.txt")
def test_fs_with_explicit_root(self, tmp_path):
"""When root is explicitly provided, it should be used."""
custom_root = str(tmp_path / "custom_storage")
storage = OpenDALStorage(scheme="fs", root=custom_root)
assert Path(custom_root).is_dir()
storage.save("test_explicit_root.txt", b"world")
assert storage.exists("test_explicit_root.txt")
assert storage.load_once("test_explicit_root.txt") == b"world"
# Cleanup
storage.delete("test_explicit_root.txt")
def test_fs_with_env_var_root(self, tmp_path, monkeypatch):
"""When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs."""
env_root = str(tmp_path / "env_storage")
monkeypatch.setenv("OPENDAL_FS_ROOT", env_root)
# Ensure .env file doesn't interfere
monkeypatch.chdir(tmp_path)
storage = OpenDALStorage(scheme="fs")
assert Path(env_root).is_dir()
storage.save("test_env_root.txt", b"env_data")
assert storage.exists("test_env_root.txt")
assert storage.load_once("test_env_root.txt") == b"env_data"
# Cleanup
storage.delete("test_env_root.txt")

View File

@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import (
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.provider_manager import ProviderManager
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetPermissionService, DatasetService
@ -1121,7 +1122,7 @@ class TestDatasetIndexingEstimateApi:
def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
storage_type=StorageType.LOCAL,
key="key",
name="name.txt",
size=1,

View File

@ -2,6 +2,7 @@ from datetime import datetime
from unittest.mock import MagicMock, patch
import controllers.console.explore.banner as banner_module
from models.enums import BannerStatus
def unwrap(func):
@ -20,7 +21,7 @@ class TestBannerApi:
banner.content = {"text": "hello"}
banner.link = "https://example.com"
banner.sort = 1
banner.status = "enabled"
banner.status = BannerStatus.ENABLED
banner.created_at = datetime(2024, 1, 1)
query = MagicMock()
@ -54,7 +55,7 @@ class TestBannerApi:
banner.content = {"text": "fallback"}
banner.link = None
banner.sort = 1
banner.status = "enabled"
banner.status = BannerStatus.ENABLED
banner.created_at = None
query = MagicMock()

Some files were not shown because too many files have changed in this diff Show More