Merge branch 'main' into feat/model-total-credits

This commit is contained in:
Coding On Star 2025-12-19 14:56:10 +08:00 committed by GitHub
commit 8d884aad29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1020 changed files with 39160 additions and 4092 deletions

View File

@ -6,7 +6,7 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

8
.github/CODEOWNERS vendored
View File

@ -6,6 +6,12 @@
* @crazywoola @laipz8200 @Yeuoly
# CODEOWNERS file
.github/CODEOWNERS @laipz8200 @crazywoola
# Docs
docs/ @crazywoola
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
@ -116,7 +122,7 @@ api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200
api/migrations/ @snakevash @laipz8200 @MRZHUH
# Frontend
web/ @iamjoel

View File

@ -93,4 +93,12 @@ jobs:
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY

View File

@ -79,7 +79,7 @@ jobs:
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
working-directory: ./web

View File

@ -90,7 +90,7 @@ jobs:
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'

View File

@ -55,7 +55,7 @@ jobs:
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: env.FILES_CHANGED == 'true'

View File

@ -13,6 +13,7 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
@ -21,14 +22,7 @@ jobs:
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: web/**
- name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
@ -36,23 +30,355 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Restore Jest cache
uses: actions/cache@v4
with:
path: web/.cache/jest
key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-jest-
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm test
run: |
pnpm exec jest \
--ci \
--maxWorkers=100% \
--coverage \
--passWithNoTests
- name: Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Jest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v4
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error

View File

@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],

View File

@ -543,6 +543,25 @@ APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Aliyun SLS Logstore Configuration
# Aliyun Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365 days 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@ -670,4 +689,9 @@ ANNOTATION_IMPORT_MIN_RECORDS=1
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@ -84,7 +84,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -75,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_import_modules,
ext_logging,
ext_login,
ext_logstore,
ext_mail,
ext_migrate,
ext_orjson,
@ -105,6 +106,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_logstore, # Initialize logstore after storage, before celery
ext_celery,
ext_login,
ext_mail,

View File

@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
default=600.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
@ -1270,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
)
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1295,6 +1310,7 @@ class FeatureConfig(
PositionConfig,
RagEtlConfig,
RepositoryConfig,
SandboxExpiredRecordsCleanConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,

View File

@ -22,7 +22,12 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
email_password_login_enabled,
setup_required,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
@ -79,6 +84,7 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)

View File

@ -140,6 +140,18 @@ class DataSourceNotionListApi(Resource):
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
datasource_parameters = {}
if datasource_parameters_str:
try:
datasource_parameters = json.loads(datasource_parameters_str)
if not isinstance(datasource_parameters, dict):
raise ValueError("datasource_parameters must be a JSON object.")
except json.JSONDecodeError:
raise ValueError("Invalid datasource_parameters JSON format.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -187,7 +199,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
datasource_parameters=datasource_parameters,
provider_type=datasource_runtime.datasource_provider_type(),
)
)
@ -218,14 +230,14 @@ class DataSourceNotionListApi(Resource):
@console_ns.route(
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
@ -239,11 +251,10 @@ class DataSourceNotionApi(Resource):
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_workspace_id="",
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),

View File

@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None

View File

@ -4,7 +4,7 @@ from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -975,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
return recommended_plugins

View File

@ -1,31 +1,40 @@
from typing import Literal
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 50:
raise ValueError("Name must be between 1 to 50 characters.")
return name
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
parser_tags = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
)
@ -43,7 +52,7 @@ class TagListApi(Resource):
return tags, 200
@console_ns.expect(parser_tags)
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -53,22 +62,17 @@ class TagListApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_tags.parse_args()
tag = TagService.save_tags(args)
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump())
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
parser_tag_id = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@console_ns.expect(parser_tag_id)
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_tag_id.parse_args()
tag = TagService.update_tags(args, tag_id)
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource):
return 204
parser_create = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@console_ns.expect(parser_create)
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_create.parse_args()
TagService.save_tag_binding(args)
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
parser_remove = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@console_ns.expect(parser_remove)
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = parser_remove.parse_args()
TagService.delete_tag_binding(args)
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump())
return {"result": "success"}, 200

View File

@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar
from flask import abort, request
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption
FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code"
# Error messages for decryption failures
ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
@ -419,3 +429,75 @@ def annotation_import_concurrency_limit(view: Callable[P, R]):
return view(*args, **kwargs)
return decorated
def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None:
"""
Helper to decode a Base64 encoded field in the request payload.
Args:
field_name: Name of the field to decode
error_class: Exception class to raise on decoding failure
error_message: Error message to include in the exception
"""
if not request or not request.is_json:
return
# Get the payload dict - it's cached and mutable
payload = request.get_json()
if not payload or field_name not in payload:
return
encoded_value = payload[field_name]
decoded_value = FieldEncryption.decrypt_field(encoded_value)
# If decoding failed, raise error immediately
if decoded_value is None:
raise error_class(error_message)
# Update payload dict in-place with decoded value
# Since payload is a mutable dict and get_json() returns the cached reference,
# modifying it will affect all subsequent accesses including console_ns.payload
payload[field_name] = decoded_value
def decrypt_password_field(view: Callable[P, R]):
"""
Decorator to decrypt password field in request payload.
Automatically decrypts the 'password' field if encryption is enabled.
If decryption fails, raises AuthenticationFailedError.
Usage:
@decrypt_password_field
def post(self):
args = LoginPayload.model_validate(console_ns.payload)
# args.password is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA)
return view(*args, **kwargs)
return decorated
def decrypt_code_field(view: Callable[P, R]):
"""
Decorator to decrypt verification code field in request payload.
Automatically decrypts the 'code' field if encryption is enabled.
If decryption fails, raises EmailCodeError.
Usage:
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
# args.code is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE)
return view(*args, **kwargs)
return decorated

View File

@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None

View File

@ -1,7 +1,8 @@
import logging
from flask import request
from flask_restx import fields, marshal_with, reqparse
from flask_restx import fields, marshal_with
from pydantic import BaseModel, field_validator
from werkzeug.exceptions import InternalServerError
import services
@ -20,6 +21,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
@ -29,6 +31,25 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError,
)
from ..common.schema import register_schema_models
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
register_schema_models(web_ns, TextToAudioPayload)
logger = logging.getLogger(__name__)
@ -88,6 +109,7 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
@web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(
@ -102,18 +124,11 @@ class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
"""Convert text to audio"""
try:
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = payload.message_id
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)

View File

@ -1,9 +1,11 @@
import logging
from typing import Any, Literal
from flask_restx import reqparse
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the completion")
query: str = Field(default="", description="Query text for completion")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
response_mode: Literal["blocking", "streaming"] | None = Field(
default=None, description="Response mode: blocking or streaming"
)
retriever_from: str = Field(default="web_app", description="Source of retriever")
class ChatMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the chat")
query: str = Field(description="User query/message")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
response_mode: Literal["blocking", "streaming"] | None = Field(
default=None, description="Response mode: blocking or streaming"
)
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
retriever_from: str = Field(default="web_app", description="Source of retriever")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user
@web_ns.route("/completion-messages")
class CompletionApi(WebApiResource):
@web_ns.doc("Create Completion Message")
@web_ns.doc(description="Create a completion message for text generation applications.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
"query": {"description": "Query text for completion", "type": "string", "required": False},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@ -67,18 +88,10 @@ class CompletionApi(WebApiResource):
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:
@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
@web_ns.doc("Create Chat Message")
@web_ns.doc(description="Create a chat message for conversational applications.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
"query": {"description": "User query/message", "type": "string", "required": True},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@ -173,20 +171,10 @@ class ChatApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
json_schema: str | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
def validate_json_schema(cls, schema: str | None) -> str | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema

View File

@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@ -175,6 +176,13 @@ class BaseAppGenerator:
value = True
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _:
raise AssertionError("this statement should be unreachable.")

View File

@ -342,9 +342,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=event_type,
)
else:
yield self._agent_message_to_stream_response(

View File

@ -5,7 +5,7 @@ from threading import Thread
from typing import Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from configs import dify_config
@ -54,6 +54,20 @@ class MessageCycleManager:
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state
self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent:
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE
with Session(db.engine, expire_on_commit=False) as session:
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
if has_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE
return StreamEvent.MESSAGE
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
@ -214,7 +228,11 @@ class MessageCycleManager:
return None
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
self,
answer: str,
message_id: str,
from_variable_selector: list[str] | None = None,
event_type: StreamEvent | None = None,
) -> MessageStreamResponse:
"""
Message to stream response.
@ -222,16 +240,12 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
event=event_type,
event=event_type or StreamEvent.MESSAGE,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

View File

@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model
- Model provider display
![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png)
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.
- Selectable model list display
![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png)
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png)
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
- Provider/model credential authentication
![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png)
![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png)
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
## Structure
![](./docs/en_US/images/index/image-20231210165243632.png)
Model Runtime is divided into three layers:
- The outermost layer is the factory method
@ -60,9 +46,6 @@ Model Runtime is divided into three layers:
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
## Next Steps
## Documentation
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
- Implement interface methods: [Link](./docs/en_US/interfaces.md)
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).

View File

@ -18,34 +18,20 @@
- 模型供应商展示
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。
- 可选择的模型列表展示
![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png)
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png)
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。
- 供应商/模型凭据鉴权
![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png)
![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png)
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO上图 2 为模型凭据 DEMO。
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。
## 结构
![](./docs/zh_Hans/images/index/image-20231210165243632.png)
Model Runtime 分三层:
- 最外层为工厂方法
@ -59,8 +45,7 @@ Model Runtime 分三层:
对于供应商/模型凭据,有两种情况
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
![Alt text](docs/zh_Hans/images/index/image.png)
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
@ -74,20 +59,6 @@ Model Runtime 分三层:
- 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
## 下一步
## 文档
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
当添加后,这里将会出现一个新的供应商
![Alt text](docs/zh_Hans/images/index/image-1.png)
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
![Alt text](docs/zh_Hans/images/index/image-2.png)
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。

View File

@ -39,7 +39,7 @@ from core.trigger.errors import (
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None,
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:

View File

@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
"""
credential_id: str | None = None
notion_workspace_id: str
notion_workspace_id: str | None = ""
notion_obj_id: str
notion_page_type: str
document: Document | None = None

View File

@ -166,7 +166,7 @@ class ExtractProcessor:
elif extract_setting.datasource_type == DatasourceType.NOTION:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "",
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document,

View File

@ -45,6 +45,6 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
except concurrent.futures.TimeoutError:
raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}")
if all(encoding["encoding"] is None for encoding in encodings):
if all(encoding.encoding is None for encoding in encodings):
raise RuntimeError(f"Could not detect encoding for {file_path}")
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
return [enc for enc in encodings if enc.encoding is not None]

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import codecs
import re
from typing import Any
@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._fixed_separator = fixed_separator
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""]
def split_text(self, text: str) -> list[str]:
@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = re.split(r" +", text)
else:
splits = text.split(separator)
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
if self._keep_separator:
splits = [s + separator for s in splits[:-1]] + splits[-1:]
else:
splits = list(text)
if separator == "\n":
@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = separator if self._keep_separator else ""
_separator = "" if self._keep_separator else separator
s_lens = self._length_function(splits)
if separator != "":
for s, s_len in zip(splits, s_lens):

View File

@ -1,3 +1,4 @@
import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
if not isinstance(value, dict):
raise ValueError(f"{key} must be a JSON object")
schema = variable.json_schema
if not schema:
continue
if not value:
continue
try:
Draft7Validator(schema).validate(value)
json_schema = json.loads(schema)
except json.JSONDecodeError as e:
raise ValueError(f"{schema} must be a valid JSON object")
try:
json_value = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"{value} must be a valid JSON object")
try:
Draft7Validator(json_schema).validate(json_value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = value
node_inputs[key] = json_value

View File

@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"
@ -69,6 +69,53 @@ if [[ "${MODE}" == "worker" ]]; then
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
elif [[ "${MODE}" == "job" ]]; then
# Job mode: Run a one-time Flask command and exit
# Pass Flask command and arguments via container args
# Example K8s usage:
# args:
# - create-tenant
# - --email
# - admin@example.com
#
# Example Docker usage:
# docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com
if [[ $# -eq 0 ]]; then
echo "Error: No command specified for job mode."
echo ""
echo "Usage examples:"
echo " Kubernetes:"
echo " args: [create-tenant, --email, admin@example.com]"
echo ""
echo " Docker:"
echo " docker run -e MODE=job dify-api create-tenant --email admin@example.com"
echo ""
echo "Available commands:"
echo " create-tenant, reset-password, reset-email, upgrade-db,"
echo " vdb-migrate, install-plugins, and more..."
echo ""
echo "Run 'flask --help' to see all available commands."
exit 1
fi
echo "Running Flask job command: flask $*"
# Temporarily disable exit on error to capture exit code
set +e
flask "$@"
JOB_EXIT_CODE=$?
set -e
if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then
echo "Job completed successfully."
else
echo "Job failed with exit code ${JOB_EXIT_CODE}."
fi
exit ${JOB_EXIT_CODE}
else
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug

View File

@ -0,0 +1,74 @@
"""
Logstore extension for Dify application.
This extension initializes the logstore (Aliyun SLS) on application startup,
creating necessary projects, logstores, and indexes if they don't exist.
"""
import logging
import os
from dotenv import load_dotenv
from dify_app import DifyApp
logger = logging.getLogger(__name__)
def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
"ALIYUN_SLS_ENDPOINT",
"ALIYUN_SLS_REGION",
"ALIYUN_SLS_PROJECT_NAME",
]
all_set = all(os.environ.get(var) for var in required_vars)
if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
return all_set
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
Args:
app: The Dify application instance
"""
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...")
# Create logstore client and initialize project/logstores/indexes
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured

View File

View File

@ -0,0 +1,890 @@
import logging
import os
import threading
import time
from collections.abc import Sequence
from typing import Any
import sqlalchemy as sa
from aliyun.log import ( # type: ignore[import-untyped]
GetLogsRequest,
IndexConfig,
IndexKeyConfig,
IndexLineConfig,
LogClient,
LogItem,
PutLogsRequest,
)
from aliyun.log.auth import AUTH_VERSION_4 # type: ignore[import-untyped]
from aliyun.log.logexception import LogException # type: ignore[import-untyped]
from dotenv import load_dotenv
from sqlalchemy.orm import DeclarativeBase
from configs import dify_config
from extensions.logstore.aliyun_logstore_pg import AliyunLogStorePG
logger = logging.getLogger(__name__)
class AliyunLogStore:
"""
Singleton class for Aliyun SLS LogStore operations.
Ensures only one instance exists to prevent multiple PG connection pools.
"""
_instance: "AliyunLogStore | None" = None
_initialized: bool = False
# Track delayed PG connection for newly created projects
_pg_connection_timer: threading.Timer | None = None
_pg_connection_delay: int = 90 # delay seconds
# Default tokenizer for text/json fields and full-text index
# Common delimiters: comma, space, quotes, punctuation, operators, brackets, special chars
DEFAULT_TOKEN_LIST = [
",",
" ",
'"',
'"',
";",
"=",
"(",
")",
"[",
"]",
"{",
"}",
"?",
"@",
"&",
"<",
">",
"/",
":",
"\n",
"\t",
]
def __new__(cls) -> "AliyunLogStore":
"""Implement singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
project_des = "dify"
workflow_execution_logstore = "workflow_execution"
workflow_node_execution_logstore = "workflow_node_execution"
@staticmethod
def _sqlalchemy_type_to_logstore_type(column: Any) -> str:
"""
Map SQLAlchemy column type to Aliyun LogStore index type.
Args:
column: SQLAlchemy column object
Returns:
LogStore index type: 'text', 'long', 'double', or 'json'
"""
column_type = column.type
# Integer types -> long
if isinstance(column_type, (sa.Integer, sa.BigInteger, sa.SmallInteger)):
return "long"
# Float types -> double
if isinstance(column_type, (sa.Float, sa.Numeric)):
return "double"
# String and Text types -> text
if isinstance(column_type, (sa.String, sa.Text)):
return "text"
# DateTime -> text (stored as ISO format string in logstore)
if isinstance(column_type, sa.DateTime):
return "text"
# Boolean -> long (stored as 0/1)
if isinstance(column_type, sa.Boolean):
return "long"
# JSON -> json
if isinstance(column_type, sa.JSON):
return "json"
# Default to text for unknown types
return "text"
@staticmethod
def _generate_index_keys_from_model(model_class: type[DeclarativeBase]) -> dict[str, IndexKeyConfig]:
"""
Automatically generate LogStore field index configuration from SQLAlchemy model.
This method introspects the SQLAlchemy model's column definitions and creates
corresponding LogStore index configurations. When the PG schema is updated via
Flask-Migrate, this method will automatically pick up the new fields on next startup.
Args:
model_class: SQLAlchemy model class (e.g., WorkflowRun, WorkflowNodeExecutionModel)
Returns:
Dictionary mapping field names to IndexKeyConfig objects
"""
index_keys = {}
# Iterate over all mapped columns in the model
if hasattr(model_class, "__mapper__"):
for column_name, column_property in model_class.__mapper__.columns.items():
# Skip relationship properties and other non-column attributes
if not hasattr(column_property, "type"):
continue
# Map SQLAlchemy type to LogStore type
logstore_type = AliyunLogStore._sqlalchemy_type_to_logstore_type(column_property)
# Create index configuration
# - text fields: case_insensitive for better search, with tokenizer and Chinese support
# - all fields: doc_value=True for analytics
if logstore_type == "text":
index_keys[column_name] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=AliyunLogStore.DEFAULT_TOKEN_LIST,
chinese=True,
)
else:
index_keys[column_name] = IndexKeyConfig(index_type=logstore_type, doc_value=True)
# Add log_version field (not in PG model, but used in logstore for versioning)
index_keys["log_version"] = IndexKeyConfig(index_type="long", doc_value=True)
return index_keys
def __init__(self) -> None:
# Skip initialization if already initialized (singleton pattern)
if self.__class__._initialized:
return
load_dotenv()
self.access_key_id: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_ID", "")
self.access_key_secret: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_SECRET", "")
self.endpoint: str = os.environ.get("ALIYUN_SLS_ENDPOINT", "")
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
)
# Append Dify identification to the existing user agent
original_user_agent = self.client._user_agent # pyright: ignore[reportPrivateUsage]
dify_version = dify_config.project.version
enhanced_user_agent = f"Dify,Dify-{dify_version},{original_user_agent}"
self.client.set_user_agent(enhanced_user_agent)
# PG client will be initialized in init_project_logstore
self._pg_client: AliyunLogStorePG | None = None
self._use_pg_protocol: bool = False
self.__class__._initialized = True
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
return self._use_pg_protocol
def _attempt_pg_connection_init(self) -> bool:
"""
Attempt to initialize PG connection.
This method tries to establish PG connection and performs necessary checks.
It's used both for immediate connection (existing projects) and delayed connection (new projects).
Returns:
True if PG connection was successfully established, False otherwise.
"""
if not self.pg_mode_enabled or not self._pg_client:
return False
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
self._use_pg_protocol = False
return False
def _delayed_pg_connection_init(self) -> None:
"""
Delayed initialization of PG connection for newly created projects.
This method is called by a background timer 3 minutes after project creation.
"""
# Double check conditions in case state changed
if self._use_pg_protocol:
return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
def init_project_logstore(self):
"""
Initialize project, logstore, index, and PG connection.
This method should be called once during application startup to ensure
all required resources exist and connections are established.
"""
# Step 1: Ensure project and logstore exist
project_is_new = False
if not self.is_project_exist():
self.create_project()
project_is_new = True
self.create_logstore_if_not_exist()
# Step 2: Initialize PG client and connection (if enabled)
if not self.pg_mode_enabled:
logger.info("PG mode is disabled. Will use SDK mode.")
return
# Create PG client if not already created
if self._pg_client is None:
logger.info("Initializing PG client for project %s...", self.project_name)
self._pg_client = AliyunLogStorePG(
self.access_key_id, self.access_key_secret, self.endpoint, self.project_name
)
# Step 3: Establish PG connection based on project status
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
self.__class__._pg_connection_delay,
self._delayed_pg_connection_init,
)
self.__class__._pg_connection_timer.daemon = True # Don't block app shutdown
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
"""
Check if scan_index is enabled for all logstores.
If any logstore has scan_index=false, disable PG protocol.
This is necessary because PG protocol requires scan_index to be enabled.
"""
logstore_name_list = [
AliyunLogStore.workflow_execution_logstore,
AliyunLogStore.workflow_node_execution_logstore,
]
for logstore_name in logstore_name_list:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
logstore_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
if self._pg_client:
self._pg_client.close()
self._pg_client = None
return
def is_project_exist(self) -> bool:
try:
self.client.get_project(self.project_name)
return True
except Exception as e:
if e.args[0] == "ProjectNotExist":
return False
else:
raise e
def create_project(self):
try:
self.client.create_project(self.project_name, AliyunLogStore.project_des)
logger.info("Project %s created successfully", self.project_name)
except LogException as e:
logger.exception(
"Failed to create project %s: errorCode=%s, errorMessage=%s, requestId=%s",
self.project_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def is_logstore_exist(self, logstore_name: str) -> bool:
try:
_ = self.client.get_logstore(self.project_name, logstore_name)
return True
except Exception as e:
if e.args[0] == "LogStoreNotExist":
return False
else:
raise e
def create_logstore_if_not_exist(self) -> None:
logstore_name_list = [
AliyunLogStore.workflow_execution_logstore,
AliyunLogStore.workflow_node_execution_logstore,
]
for logstore_name in logstore_name_list:
if not self.is_logstore_exist(logstore_name):
try:
self.client.create_logstore(
project_name=self.project_name, logstore_name=logstore_name, ttl=self.logstore_ttl
)
logger.info("logstore %s created successfully", logstore_name)
except LogException as e:
logger.exception(
"Failed to create logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
# Ensure index contains all Dify-required fields
# This intelligently merges with existing config, preserving custom indexes
self.ensure_index_config(logstore_name)
def is_index_exist(self, logstore_name: str) -> bool:
try:
_ = self.client.get_index_config(self.project_name, logstore_name)
return True
except Exception as e:
if e.args[0] == "IndexConfigNotExist":
return False
else:
raise e
def get_existing_index_config(self, logstore_name: str) -> IndexConfig | None:
"""
Get existing index configuration from logstore.
Args:
logstore_name: Name of the logstore
Returns:
IndexConfig object if index exists, None otherwise
"""
try:
response = self.client.get_index_config(self.project_name, logstore_name)
return response.get_index_config()
except Exception as e:
if e.args[0] == "IndexConfigNotExist":
return None
else:
logger.exception("Failed to get index config for logstore %s", logstore_name)
raise e
def _get_workflow_execution_index_keys(self) -> dict[str, IndexKeyConfig]:
"""
Get field index configuration for workflow_execution logstore.
This method automatically generates index configuration from the WorkflowRun SQLAlchemy model.
When the PG schema is updated via Flask-Migrate, the index configuration will be automatically
updated on next application startup.
"""
from models.workflow import WorkflowRun
index_keys = self._generate_index_keys_from_model(WorkflowRun)
# Add custom fields that are in logstore but not in PG model
# These fields are added by the repository layer
index_keys["error_message"] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=self.DEFAULT_TOKEN_LIST,
chinese=True,
) # Maps to 'error' in PG
index_keys["started_at"] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=self.DEFAULT_TOKEN_LIST,
chinese=True,
) # Maps to 'created_at' in PG
logger.info("Generated %d index keys for workflow_execution from WorkflowRun model", len(index_keys))
return index_keys
def _get_workflow_node_execution_index_keys(self) -> dict[str, IndexKeyConfig]:
"""
Get field index configuration for workflow_node_execution logstore.
This method automatically generates index configuration from the WorkflowNodeExecutionModel.
When the PG schema is updated via Flask-Migrate, the index configuration will be automatically
updated on next application startup.
"""
from models.workflow import WorkflowNodeExecutionModel
index_keys = self._generate_index_keys_from_model(WorkflowNodeExecutionModel)
logger.debug(
"Generated %d index keys for workflow_node_execution from WorkflowNodeExecutionModel", len(index_keys)
)
return index_keys
def _get_index_config(self, logstore_name: str) -> IndexConfig:
"""
Get index configuration for the specified logstore.
Args:
logstore_name: Name of the logstore
Returns:
IndexConfig object with line and field indexes
"""
# Create full-text index (line config) with tokenizer
line_config = IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True)
# Get field index configuration based on logstore name
field_keys = {}
if logstore_name == AliyunLogStore.workflow_execution_logstore:
field_keys = self._get_workflow_execution_index_keys()
elif logstore_name == AliyunLogStore.workflow_node_execution_logstore:
field_keys = self._get_workflow_node_execution_index_keys()
# key_config_list should be a dict, not a list
# Create index config with both line and field indexes
return IndexConfig(line_config=line_config, key_config_list=field_keys, scan_index=True)
def create_index(self, logstore_name: str) -> None:
"""
Create index for the specified logstore with both full-text and field indexes.
Field indexes are automatically generated from the corresponding SQLAlchemy model.
"""
index_config = self._get_index_config(logstore_name)
try:
self.client.create_index(self.project_name, logstore_name, index_config)
logger.info(
"index for %s created successfully with %d field indexes",
logstore_name,
len(index_config.key_config_list or {}),
)
except LogException as e:
logger.exception(
"Failed to create index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def _merge_index_configs(
self, existing_config: IndexConfig, required_keys: dict[str, IndexKeyConfig], logstore_name: str
) -> tuple[IndexConfig, bool]:
"""
Intelligently merge existing index config with Dify's required field indexes.
This method:
1. Preserves all existing field indexes in logstore (including custom fields)
2. Adds missing Dify-required fields
3. Updates fields where type doesn't match (with json/text compatibility)
4. Corrects case mismatches (e.g., if Dify needs 'status' but logstore has 'Status')
Type compatibility rules:
- json and text types are considered compatible (users can manually choose either)
- All other type mismatches will be corrected to match Dify requirements
Note: Logstore is case-sensitive and doesn't allow duplicate fields with different cases.
Case mismatch means: existing field name differs from required name only in case.
Args:
existing_config: Current index configuration from logstore
required_keys: Dify's required field index configurations
logstore_name: Name of the logstore (for logging)
Returns:
Tuple of (merged_config, needs_update)
"""
# key_config_list is already a dict in the SDK
# Make a copy to avoid modifying the original
existing_keys = dict(existing_config.key_config_list) if existing_config.key_config_list else {}
# Track changes
needs_update = False
case_corrections = [] # Fields that need case correction (e.g., 'Status' -> 'status')
missing_fields = []
type_mismatches = []
# First pass: Check for and resolve case mismatches with required fields
# Note: Logstore itself doesn't allow duplicate fields with different cases,
# so we only need to check if the existing case matches the required case
for required_name in required_keys:
lower_name = required_name.lower()
# Find key that matches case-insensitively but not exactly
wrong_case_key = None
for existing_key in existing_keys:
if existing_key.lower() == lower_name and existing_key != required_name:
wrong_case_key = existing_key
break
if wrong_case_key:
# Field exists but with wrong case (e.g., 'Status' when we need 'status')
# Remove the wrong-case key, will be added back with correct case later
case_corrections.append((wrong_case_key, required_name))
del existing_keys[wrong_case_key]
needs_update = True
# Second pass: Check each required field
for required_name, required_config in required_keys.items():
# Check for exact match (case-sensitive)
if required_name in existing_keys:
existing_type = existing_keys[required_name].index_type
required_type = required_config.index_type
# Check if type matches
# Special case: json and text are interchangeable for JSON content fields
# Allow users to manually configure text instead of json (or vice versa) without forcing updates
is_compatible = existing_type == required_type or ({existing_type, required_type} == {"json", "text"})
if not is_compatible:
type_mismatches.append((required_name, existing_type, required_type))
# Update with correct type
existing_keys[required_name] = required_config
needs_update = True
# else: field exists with compatible type, no action needed
else:
# Field doesn't exist (may have been removed in first pass due to case conflict)
missing_fields.append(required_name)
existing_keys[required_name] = required_config
needs_update = True
# Log changes
if missing_fields:
logger.info(
"Logstore %s: Adding %d missing Dify-required fields: %s",
logstore_name,
len(missing_fields),
", ".join(missing_fields[:10]) + ("..." if len(missing_fields) > 10 else ""),
)
if type_mismatches:
logger.info(
"Logstore %s: Fixing %d type mismatches: %s",
logstore_name,
len(type_mismatches),
", ".join([f"{name}({old}->{new})" for name, old, new in type_mismatches[:5]])
+ ("..." if len(type_mismatches) > 5 else ""),
)
if case_corrections:
logger.info(
"Logstore %s: Correcting %d field name cases: %s",
logstore_name,
len(case_corrections),
", ".join([f"'{old}' -> '{new}'" for old, new in case_corrections[:5]])
+ ("..." if len(case_corrections) > 5 else ""),
)
# Create merged config
# key_config_list should be a dict, not a list
# Preserve the original scan_index value - don't force it to True
merged_config = IndexConfig(
line_config=existing_config.line_config
or IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True),
key_config_list=existing_keys,
scan_index=existing_config.scan_index,
)
return merged_config, needs_update
def ensure_index_config(self, logstore_name: str) -> None:
"""
Ensure index configuration includes all Dify-required fields.
This method intelligently manages index configuration:
1. If index doesn't exist, create it with Dify's required fields
2. If index exists:
- Check if all Dify-required fields are present
- Check if field types match requirements
- Only update if fields are missing or types are incorrect
- Preserve any additional custom index configurations
This approach allows users to add their own custom indexes without being overwritten.
"""
# Get Dify's required field indexes
required_keys = {}
if logstore_name == AliyunLogStore.workflow_execution_logstore:
required_keys = self._get_workflow_execution_index_keys()
elif logstore_name == AliyunLogStore.workflow_node_execution_logstore:
required_keys = self._get_workflow_node_execution_index_keys()
# Check if index exists
existing_config = self.get_existing_index_config(logstore_name)
if existing_config is None:
# Index doesn't exist, create it
logger.info(
"Logstore %s: Index doesn't exist, creating with %d required fields",
logstore_name,
len(required_keys),
)
self.create_index(logstore_name)
else:
merged_config, needs_update = self._merge_index_configs(existing_config, required_keys, logstore_name)
if needs_update:
logger.info("Logstore %s: Updating index to include Dify-required fields", logstore_name)
try:
self.client.update_index(self.project_name, logstore_name, merged_config)
logger.info(
"Logstore %s: Index updated successfully, now has %d total field indexes",
logstore_name,
len(merged_config.key_config_list or {}),
)
except LogException as e:
logger.exception(
"Failed to update index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
else:
logger.info(
"Logstore %s: Index already contains all %d Dify-required fields with correct types, "
"no update needed",
logstore_name,
len(required_keys),
)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]]) -> None:
# Route to PG or SDK based on protocol availability
if self._use_pg_protocol and self._pg_client:
self._pg_client.put_log(logstore, contents, self.log_enabled)
else:
log_item = LogItem(contents=contents)
request = PutLogsRequest(project=self.project_name, logstore=logstore, logitems=[log_item])
if self.log_enabled:
logger.info(
"[LogStore-SDK] PUT_LOG | logstore=%s | project=%s | items_count=%d",
logstore,
self.project_name,
len(contents),
)
try:
self.client.put_logs(request)
except LogException as e:
logger.exception(
"Failed to put logs to logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def get_logs(
self,
logstore: str,
from_time: int,
to_time: int,
topic: str = "",
query: str = "",
line: int = 100,
offset: int = 0,
reverse: bool = True,
) -> list[dict]:
request = GetLogsRequest(
project=self.project_name,
logstore=logstore,
fromTime=from_time,
toTime=to_time,
topic=topic,
query=query,
line=line,
offset=offset,
reverse=reverse,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
"from_time=%d | to_time=%d | line=%d | offset=%d | reverse=%s",
logstore,
self.project_name,
query,
from_time,
to_time,
line,
offset,
reverse,
)
try:
response = self.client.get_logs(request)
result = []
logs = response.get_logs() if response else []
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except LogException as e:
logger.exception(
"Failed to get logs from logstore %s with query '%s': errorCode=%s, errorMessage=%s, requestId=%s",
logstore,
query,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def execute_sql(
self,
sql: str,
logstore: str | None = None,
query: str = "*",
from_time: int | None = None,
to_time: int | None = None,
power_sql: bool = False,
) -> list[dict]:
"""
Execute SQL query for aggregation and analysis.
Args:
sql: SQL query string (SELECT statement)
logstore: Name of the logstore (required)
query: Search/filter query for SDK mode (default: "*" for all logs).
Only used in SDK mode. PG mode ignores this parameter.
from_time: Start time (Unix timestamp) - only used in SDK mode
to_time: End time (Unix timestamp) - only used in SDK mode
power_sql: Whether to use enhanced SQL mode (default: False)
Returns:
List of result rows as dictionaries
Note:
- PG mode: Only executes the SQL directly
- SDK mode: Combines query and sql as "query | sql"
"""
# Logstore is required
if not logstore:
raise ValueError("logstore parameter is required for execute_sql")
# Route to PG or SDK based on protocol availability
if self._use_pg_protocol and self._pg_client:
# PG mode: execute SQL directly (ignore query parameter)
return self._pg_client.execute_sql(sql, logstore, self.log_enabled)
else:
# SDK mode: combine query and sql as "query | sql"
full_query = f"{query} | {sql}"
# Provide default time range if not specified
if from_time is None:
from_time = 0
if to_time is None:
to_time = int(time.time()) # now
request = GetLogsRequest(
project=self.project_name,
logstore=logstore,
fromTime=from_time,
toTime=to_time,
query=full_query,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
logstore,
self.project_name,
from_time,
to_time,
query,
sql,
)
try:
response = self.client.get_logs(request)
result = []
logs = response.get_logs() if response else []
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except LogException as e:
logger.exception(
"Failed to execute SQL, logstore %s: errorCode=%s, errorMessage=%s, requestId=%s, full_query=%s",
logstore,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
full_query,
)
raise
if __name__ == "__main__":
aliyun_logstore = AliyunLogStore()
# aliyun_logstore.init_project_logstore()
aliyun_logstore.put_log(AliyunLogStore.workflow_execution_logstore, [("key1", "value1")])

View File

@ -0,0 +1,407 @@
import logging
import os
import socket
import time
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from configs import dify_config
logger = logging.getLogger(__name__)
class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
Initialize PG connection for SLS.
Args:
access_key_id: Aliyun access key ID
access_key_secret: Aliyun access key secret
endpoint: SLS endpoint
project_name: SLS project name
"""
self._access_key_id = access_key_id
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
result = sock.connect_ex((host, port))
sock.close()
return result == 0
except Exception as e:
logger.debug("Port connectivity check failed for %s:%d: %s", host, port, str(e))
return False
def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
return False
# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
)
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
self.project_name,
)
return True
except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
try:
self._pg_pool.closeall()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
return False
@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
Automatically validates and refreshes stale connections.
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
except Exception:
logger.exception("Failed to close PG connection pool")
def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
return True
# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
"timeout",
"closed",
"broken pipe",
"reset by peer",
"no route to host",
"network",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
if not contents:
return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
logger.info(
"[LogStore-PG] PUT_LOG | logstore=%s | project=%s | items_count=%d",
logstore,
self.project_name,
len(contents),
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
logstore,
self.project_name,
sql,
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
for col, val in zip(columns, row):
row_dict[col] = "" if val is None else str(val)
result.append(row_dict)
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
logstore,
sql,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
# This line should never be reached due to raise above, but makes type checker happy
return []

View File

@ -0,0 +1,365 @@
"""
LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
This module provides the LogStore-based implementation for service-layer
WorkflowNodeExecutionModel operations using Aliyun SLS LogStore.
"""
import logging
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
logger = logging.getLogger(__name__)
def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNodeExecutionModel:
"""
Convert LogStore result dictionary to WorkflowNodeExecutionModel instance.
Args:
data: Dictionary from LogStore query result
Returns:
WorkflowNodeExecutionModel instance (detached from session)
Note:
The returned model is not attached to any SQLAlchemy session.
Relationship fields (like offload_data) are not loaded from LogStore.
"""
logger.debug("_dict_to_workflow_node_execution_model: data keys=%s", list(data.keys())[:5])
# Create model instance without session
model = WorkflowNodeExecutionModel()
# Map all required fields with validation
# Critical fields - must not be None
model.id = data.get("id") or ""
model.tenant_id = data.get("tenant_id") or ""
model.app_id = data.get("app_id") or ""
model.workflow_id = data.get("workflow_id") or ""
model.triggered_from = data.get("triggered_from") or ""
model.node_id = data.get("node_id") or ""
model.node_type = data.get("node_type") or ""
model.status = data.get("status") or "running" # Default status if missing
model.title = data.get("title") or ""
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.index = int(data.get("index", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
model.predecessor_node_id = data.get("predecessor_node_id")
model.node_execution_id = data.get("node_execution_id")
model.inputs = data.get("inputs")
model.process_data = data.get("process_data")
model.outputs = data.get("outputs")
model.error = data.get("error")
model.execution_metadata = data.get("execution_metadata")
# Handle datetime fields
created_at = data.get("created_at")
if created_at:
if isinstance(created_at, str):
model.created_at = datetime.fromisoformat(created_at)
elif isinstance(created_at, (int, float)):
model.created_at = datetime.fromtimestamp(created_at)
else:
model.created_at = created_at
else:
# Provide default created_at if missing
model.created_at = datetime.now()
finished_at = data.get("finished_at")
if finished_at:
if isinstance(finished_at, str):
model.finished_at = datetime.fromisoformat(finished_at)
elif isinstance(finished_at, (int, float)):
model.finished_at = datetime.fromtimestamp(finished_at)
else:
model.finished_at = finished_at
return model
class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
"""
LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
Provides service-layer database operations for WorkflowNodeExecutionModel
using LogStore SQL queries with optimized deduplication strategies.
"""
def __init__(self, session_maker: sessionmaker | None = None):
"""
Initialize the repository with LogStore client.
Args:
session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
"""
logger.debug("LogstoreAPIWorkflowNodeExecutionRepository.__init__: initializing")
self.logstore_client = AliyunLogStore()
def get_node_last_execution(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
node_id: str,
) -> WorkflowNodeExecutionModel | None:
"""
Get the most recent execution for a specific node.
Uses query syntax to get raw logs and selects the one with max log_version.
Returns the most recent execution ordered by created_at.
"""
logger.debug(
"get_node_last_execution: tenant_id=%s, app_id=%s, workflow_id=%s, node_id=%s",
tenant_id,
app_id,
workflow_id,
node_id,
)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
)
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
return None
# For SDK mode, group by id and select the one with max log_version for each group
# For PG mode, this is already done by the SQL query
if not self.logstore_client.supports_pg_protocol:
id_to_results: dict[str, list[dict[str, Any]]] = {}
for row in results:
row_id = row.get("id")
if row_id:
if row_id not in id_to_results:
id_to_results[row_id] = []
id_to_results[row_id].append(row)
# For each id, select the row with max log_version
deduplicated_results = []
for rows in id_to_results.values():
if len(rows) > 1:
max_row = max(rows, key=lambda x: int(x.get("log_version", 0)))
else:
max_row = rows[0]
deduplicated_results.append(max_row)
else:
# For PG mode, results are already deduplicated by the SQL query
deduplicated_results = results
# Sort by created_at DESC and return the most recent one
deduplicated_results.sort(
key=lambda x: x.get("created_at", 0) if isinstance(x.get("created_at"), (int, float)) else 0,
reverse=True,
)
if deduplicated_results:
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
return None
except Exception:
logger.exception("Failed to get node last execution from LogStore")
raise
def get_executions_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get all node executions for a specific workflow run.
Uses query syntax to get raw logs and selects the one with max log_version for each node execution.
Ordered by index DESC for trace visualization.
"""
logger.debug(
"[LogStore] get_executions_by_workflow_run: tenant_id=%s, app_id=%s, workflow_run_id=%s",
tenant_id,
app_id,
workflow_run_id,
)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=1000, # Get more results for node executions
reverse=False,
)
if not results:
return []
# For SDK mode, group by id and select the one with max log_version for each group
# For PG mode, this is already done by the SQL query
models = []
if not self.logstore_client.supports_pg_protocol:
id_to_results: dict[str, list[dict[str, Any]]] = {}
for row in results:
row_id = row.get("id")
if row_id:
if row_id not in id_to_results:
id_to_results[row_id] = []
id_to_results[row_id].append(row)
# For each id, select the row with max log_version
for rows in id_to_results.values():
if len(rows) > 1:
max_row = max(rows, key=lambda x: int(x.get("log_version", 0)))
else:
max_row = rows[0]
model = _dict_to_workflow_node_execution_model(max_row)
if model and model.id: # Ensure model is valid
models.append(model)
else:
# For PG mode, results are already deduplicated by the SQL query
for row in results:
model = _dict_to_workflow_node_execution_model(row)
if model and model.id: # Ensure model is valid
models.append(model)
# Sort by index DESC for trace visualization
models.sort(key=lambda x: x.index, reverse=True)
return models
except Exception:
logger.exception("Failed to get executions by workflow run from LogStore")
raise
def get_execution_by_id(
self,
execution_id: str,
tenant_id: str | None = None,
) -> WorkflowNodeExecutionModel | None:
"""
Get a workflow node execution by its ID.
Uses query syntax to get raw logs and selects the one with max log_version.
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
else:
query = f"id: {execution_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
return None
# For PG mode, result is already the latest version
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_node_execution_model(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
return _dict_to_workflow_node_execution_model(max_result)
except Exception:
logger.exception("Failed to get execution by ID from LogStore: execution_id=%s", execution_id)
raise

View File

@ -0,0 +1,757 @@
"""
LogStore API WorkflowRun Repository Implementation
This module provides the LogStore-based implementation of the APIWorkflowRunRepository
protocol. It handles service-layer WorkflowRun database operations using Aliyun SLS LogStore
with optimized queries for statistics and pagination.
Key Features:
- LogStore SQL queries for aggregation and statistics
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
"""
import logging
import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.types import (
AverageInteractionStats,
DailyRunsStats,
DailyTerminalsStats,
DailyTokenCostStats,
)
logger = logging.getLogger(__name__)
def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
"""
Convert LogStore result dictionary to WorkflowRun instance.
Args:
data: Dictionary from LogStore query result
Returns:
WorkflowRun instance
"""
logger.debug("_dict_to_workflow_run: data keys=%s", list(data.keys())[:5])
# Create model instance without session
model = WorkflowRun()
# Map all required fields with validation
# Critical fields - must not be None
model.id = data.get("id") or ""
model.tenant_id = data.get("tenant_id") or ""
model.app_id = data.get("app_id") or ""
model.workflow_id = data.get("workflow_id") or ""
model.type = data.get("type") or ""
model.triggered_from = data.get("triggered_from") or ""
model.version = data.get("version") or ""
model.status = data.get("status") or "running" # Default status if missing
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
model.inputs = data.get("inputs")
model.outputs = data.get("outputs")
model.error = data.get("error_message") or data.get("error")
# Handle datetime fields
started_at = data.get("started_at") or data.get("created_at")
if started_at:
if isinstance(started_at, str):
model.created_at = datetime.fromisoformat(started_at)
elif isinstance(started_at, (int, float)):
model.created_at = datetime.fromtimestamp(started_at)
else:
model.created_at = started_at
else:
# Provide default created_at if missing
model.created_at = datetime.now()
finished_at = data.get("finished_at")
if finished_at:
if isinstance(finished_at, str):
model.finished_at = datetime.fromisoformat(finished_at)
elif isinstance(finished_at, (int, float)):
model.finished_at = datetime.fromtimestamp(finished_at)
else:
model.finished_at = finished_at
# Compute elapsed_time from started_at and finished_at
# LogStore doesn't store elapsed_time, it's computed in WorkflowExecution domain entity
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
return model
class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
"""
LogStore implementation of APIWorkflowRunRepository.
Provides service-layer WorkflowRun database operations using LogStore SQL
with optimized query strategies:
- Use finished_at IS NOT NULL for deduplication (10-100x faster)
- Use window functions only when running status is required
- Proper time range filtering for LogStore queries
"""
def __init__(self, session_maker: sessionmaker | None = None):
"""
Initialize the repository with LogStore client.
Args:
session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
"""
logger.debug("LogstoreAPIWorkflowRunRepository.__init__: initializing")
self.logstore_client = AliyunLogStore()
# Control flag for dual-read (fallback to PostgreSQL when LogStore returns no results)
# Set to True to enable fallback for safe migration from PostgreSQL to LogStore
# Set to False for new deployments without legacy data in PostgreSQL
self._enable_dual_read = os.environ.get("LOGSTORE_DUAL_READ_ENABLED", "true").lower() == "true"
def get_paginated_workflow_runs(
self,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom | Sequence[WorkflowRunTriggeredFrom],
limit: int = 20,
last_id: str | None = None,
status: str | None = None,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs with filtering.
Uses window function for deduplication to support both running and finished states.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
app_id: Application identifier
triggered_from: Filter by trigger source(s)
limit: Maximum number of records to return (default: 20)
last_id: Cursor for pagination - ID of the last record from previous page
status: Optional filter by status
Returns:
InfiniteScrollPagination object
"""
logger.debug(
"get_paginated_workflow_runs: tenant_id=%s, app_id=%s, limit=%d, status=%s",
tenant_id,
app_id,
limit,
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
last_id_filter = ""
if last_id:
# TODO: Implement proper cursor-based pagination with created_at
logger.warning("last_id pagination not fully implemented for LogStore")
# Use window function to get latest log_version of each workflow run
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
) t
WHERE rn = 1
ORDER BY created_at DESC
LIMIT {limit + 1}
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore, from_time=None, to_time=None
)
# Check if there are more records
has_more = len(results) > limit
if has_more:
results = results[:limit]
# Convert results to WorkflowRun models
workflow_runs = [_dict_to_workflow_run(row) for row in results]
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
except Exception:
logger.exception("Failed to get paginated workflow runs from LogStore")
raise
def get_workflow_run_by_id(
self,
tenant_id: str,
app_id: str,
run_id: str,
) -> WorkflowRun | None:
"""
Get a specific workflow run by ID with tenant and app isolation.
Uses query syntax to get raw logs and selects the one with max log_version in code.
Falls back to PostgreSQL if not found in LogStore (for data consistency during migration).
"""
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
# Fallback to PostgreSQL for records created before LogStore migration
if self._enable_dual_read:
logger.debug(
"WorkflowRun not found in LogStore, falling back to PostgreSQL: "
"run_id=%s, tenant_id=%s, app_id=%s",
run_id,
tenant_id,
app_id,
)
return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id)
return None
# For PG mode, results are already deduplicated by the SQL query
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_run(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
return _dict_to_workflow_run(max_result)
except Exception:
logger.exception("Failed to get workflow run by ID from LogStore: run_id=%s", run_id)
# Try PostgreSQL fallback on any error (only if dual-read is enabled)
if self._enable_dual_read:
try:
return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id)
except Exception:
logger.exception(
"PostgreSQL fallback also failed: run_id=%s, tenant_id=%s, app_id=%s", run_id, tenant_id, app_id
)
raise
def _fallback_get_workflow_run_by_id_with_tenant(
self, run_id: str, tenant_id: str, app_id: str
) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
with Session(db.engine) as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
)
return session.scalar(stmt)
def get_workflow_run_by_id_without_tenant(
self,
run_id: str,
) -> WorkflowRun | None:
"""
Get a specific workflow run by ID without tenant/app context.
Uses query syntax to get raw logs and selects the one with max log_version.
Falls back to PostgreSQL if not found in LogStore (controlled by LOGSTORE_DUAL_READ_ENABLED).
"""
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
# Fallback to PostgreSQL for records created before LogStore migration
if self._enable_dual_read:
logger.debug("WorkflowRun not found in LogStore, falling back to PostgreSQL: run_id=%s", run_id)
return self._fallback_get_workflow_run_by_id(run_id)
return None
# For PG mode, results are already deduplicated by the SQL query
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_run(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
return _dict_to_workflow_run(max_result)
except Exception:
logger.exception("Failed to get workflow run without tenant: run_id=%s", run_id)
# Try PostgreSQL fallback on any error (only if dual-read is enabled)
if self._enable_dual_read:
try:
return self._fallback_get_workflow_run_by_id(run_id)
except Exception:
logger.exception("PostgreSQL fallback also failed: run_id=%s", run_id)
raise
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
with Session(db.engine) as session:
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
return session.scalar(stmt)
def get_workflow_runs_count(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
status: str | None = None,
time_range: str | None = None,
) -> dict[str, int]:
"""
Get workflow runs count statistics grouped by status.
Optimization: Use finished_at IS NOT NULL for completed runs (10-50x faster)
"""
logger.debug(
"get_workflow_runs_count: tenant_id=%s, app_id=%s, triggered_from=%s, status=%s",
tenant_id,
app_id,
triggered_from,
status,
)
# Build time range filter
time_filter = ""
if time_range:
# TODO: Parse time_range and convert to from_time/to_time
logger.warning("time_range filter not implemented")
# If status is provided, simple count
if status:
if status == "running":
# Running status requires window function
sql = f"""
SELECT COUNT(*) as count
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='running'
{time_filter}
) t
WHERE rn = 1
"""
else:
# Finished status uses optimized filter
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='{status}'
AND finished_at IS NOT NULL
{time_filter}
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
count = results[0]["count"] if results and len(results) > 0 else 0
return {
"total": count,
"running": count if status == "running" else 0,
"succeeded": count if status == "succeeded" else 0,
"failed": count if status == "failed" else 0,
"stopped": count if status == "stopped" else 0,
"partial-succeeded": count if status == "partial-succeeded" else 0,
}
except Exception:
logger.exception("Failed to get workflow runs count")
raise
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
"""
# Count running runs
running_sql = f"""
SELECT COUNT(*) as count
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='running'
{time_filter}
) t
WHERE rn = 1
"""
finished_results = self.logstore_client.execute_sql(
sql=finished_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
running_results = self.logstore_client.execute_sql(
sql=running_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
# Build response
status_counts = {
"running": 0,
"succeeded": 0,
"failed": 0,
"stopped": 0,
"partial-succeeded": 0,
}
total = 0
for result in finished_results:
status_val = result.get("status")
count = result.get("count", 0)
if status_val in status_counts:
status_counts[status_val] = count
total += count
# Add running count
running_count = running_results[0]["count"] if running_results and len(running_results) > 0 else 0
status_counts["running"] = running_count
total += running_count
return {"total": total} | status_counts
except Exception:
logger.exception("Failed to get workflow runs count")
raise
def get_daily_runs_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyRunsStats]:
"""
Get daily runs statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT id) (20-100x faster)
"""
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
# Optimized query: Use finished_at filter to avoid window function
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "runs": row.get("runs", 0)})
return cast(list[DailyRunsStats], response_data)
except Exception:
logger.exception("Failed to get daily runs statistics")
raise
def get_daily_terminals_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyTerminalsStats]:
"""
Get daily terminals statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT created_by) (20-100x faster)
"""
logger.debug(
"get_daily_terminals_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "terminal_count": row.get("terminal_count", 0)})
return cast(list[DailyTerminalsStats], response_data)
except Exception:
logger.exception("Failed to get daily terminals statistics")
raise
def get_daily_token_cost_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyTokenCostStats]:
"""
Get daily token cost statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + SUM(total_tokens) (20-100x faster)
"""
logger.debug(
"get_daily_token_cost_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "token_count": row.get("token_count", 0)})
return cast(list[DailyTokenCostStats], response_data)
except Exception:
logger.exception("Failed to get daily token cost statistics")
raise
def get_average_app_interaction_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[AverageInteractionStats]:
"""
Get average app interaction statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + AVG (20-100x faster)
"""
logger.debug(
"get_average_app_interaction_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM (
SELECT
DATE(from_unixtime(__time__)) AS date,
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by
) sub
GROUP BY sub.date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append(
{
"date": str(row.get("date", "")),
"interactions": float(row.get("interactions", 0)),
}
)
return cast(list[AverageInteractionStats], response_data)
except Exception:
logger.exception("Failed to get average app interaction statistics")
raise

View File

@ -0,0 +1,164 @@
import json
import logging
import os
import time
from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
Account,
CreatorUserRole,
EndUser,
)
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
logger.debug(
"LogstoreWorkflowExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from
)
# Initialize LogStore client
# Note: Project/logstore/index initialization is done at app startup via ext_logstore
self.logstore_client = AliyunLogStore()
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize SQL repository for dual-write support
self.sql_repository = SQLAlchemyWorkflowExecutionRepository(session_factory, user, app_id, triggered_from)
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
"""
Convert a domain model to a logstore model (List[Tuple[str, str]]).
Args:
domain_model: The domain model to convert
Returns:
The logstore model as a list of key-value tuples
"""
logger.debug(
"_to_logstore_model: id=%s, workflow_id=%s, status=%s",
domain_model.id_,
domain_model.workflow_id,
domain_model.status.value,
)
# Use values from constructor if provided
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
("tenant_id", self._tenant_id),
("app_id", self._app_id or ""),
("workflow_id", domain_model.workflow_id),
(
"triggered_from",
self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from),
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),
("total_steps", str(domain_model.total_steps)),
("exceptions_count", str(domain_model.exceptions_count)),
(
"created_by_role",
self._creator_user_role.value
if hasattr(self._creator_user_role, "value")
else str(self._creator_user_role),
),
("created_by", self._creator_user_id),
("started_at", domain_model.started_at.isoformat() if domain_model.started_at else ""),
("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""),
]
return logstore_model
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution domain entity to the logstore.
This method serves as a domain-to-logstore adapter that:
1. Converts the domain entity to its logstore representation
2. Persists the logstore model using Aliyun SLS
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED)
Args:
execution: The WorkflowExecution domain entity to persist
"""
logger.debug(
"save: id=%s, workflow_id=%s, status=%s", execution.id_, execution.workflow_id, execution.status.value
)
try:
logstore_model = self._to_logstore_model(execution)
self.logstore_client.put_log(AliyunLogStore.workflow_execution_logstore, logstore_model)
logger.debug("Saved workflow execution to logstore: id=%s", execution.id_)
except Exception:
logger.exception("Failed to save workflow execution to logstore: id=%s", execution.id_)
raise
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save(execution)
logger.debug("Dual-write: saved workflow execution to SQL database: id=%s", execution.id_)
except Exception:
logger.exception("Failed to dual-write workflow execution to SQL database: id=%s", execution.id_)
# Don't raise - LogStore write succeeded, SQL is just a backup

View File

@ -0,0 +1,366 @@
"""
LogStore implementation of the WorkflowNodeExecutionRepository.
This module provides a LogStore-based repository for WorkflowNodeExecution entities,
using Aliyun SLS LogStore with append-only writes and version control.
"""
import json
import logging
import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowNodeExecutionTriggeredFrom,
)
logger = logging.getLogger(__name__)
def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecution:
"""
Convert LogStore result dictionary to WorkflowNodeExecution domain model.
Args:
data: Dictionary from LogStore query result
Returns:
WorkflowNodeExecution domain model instance
"""
logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5])
# Parse JSON fields
inputs = json.loads(data.get("inputs", "{}"))
process_data = json.loads(data.get("process_data", "{}"))
outputs = json.loads(data.get("outputs", "{}"))
metadata = json.loads(data.get("execution_metadata", "{}"))
# Convert metadata to domain enum keys
domain_metadata = {}
for k, v in metadata.items():
try:
domain_metadata[WorkflowNodeExecutionMetadataKey(k)] = v
except ValueError:
# Skip invalid metadata keys
continue
# Convert status to domain enum
status = WorkflowNodeExecutionStatus(data.get("status", "running"))
# Parse datetime fields
created_at = datetime.fromisoformat(data.get("created_at", "")) if data.get("created_at") else datetime.now()
finished_at = datetime.fromisoformat(data.get("finished_at", "")) if data.get("finished_at") else None
return WorkflowNodeExecution(
id=data.get("id", ""),
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
index=int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
title=data.get("title", ""),
inputs=inputs,
process_data=process_data,
outputs=outputs,
status=status,
error=data.get("error"),
elapsed_time=float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
)
class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
LogStore implementation of the WorkflowNodeExecutionRepository interface.
This implementation uses Aliyun SLS LogStore with an append-only write strategy:
- Each save() operation appends a new record with a version timestamp
- Updates are simulated by writing new records with higher version numbers
- Queries retrieve the latest version using finished_at IS NOT NULL filter
- Multi-tenancy is maintained through tenant_id filtering
Version Strategy:
version = time.time_ns() # Nanosecond timestamp for unique ordering
"""
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
"""
logger.debug(
"LogstoreWorkflowNodeExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from
)
# Initialize LogStore client
self.logstore_client = AliyunLogStore()
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize SQL repository for dual-write support
self.sql_repository = SQLAlchemyWorkflowNodeExecutionRepository(session_factory, user, app_id, triggered_from)
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
"_to_logstore_model: id=%s, node_id=%s, status=%s",
domain_model.id,
domain_model.node_id,
domain_model.status.value,
)
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
json_converter = WorkflowRuntimeTypeConverter()
logstore_model = [
("id", domain_model.id),
("log_version", log_version), # Add log_version field for append-only writes
("tenant_id", self._tenant_id),
("app_id", self._app_id or ""),
("workflow_id", domain_model.workflow_id),
(
"triggered_from",
self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from),
),
("workflow_run_id", domain_model.workflow_execution_id or ""),
("index", str(domain_model.index)),
("predecessor_node_id", domain_model.predecessor_node_id or ""),
("node_execution_id", domain_model.node_execution_id or ""),
("node_id", domain_model.node_id),
("node_type", domain_model.node_type.value),
("title", domain_model.title),
(
"inputs",
json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"process_data",
json.dumps(json_converter.to_json_encodable(domain_model.process_data), ensure_ascii=False)
if domain_model.process_data
else "{}",
),
(
"outputs",
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),
("status", domain_model.status.value),
("error", domain_model.error or ""),
("elapsed_time", str(domain_model.elapsed_time)),
(
"execution_metadata",
json.dumps(jsonable_encoder(domain_model.metadata), ensure_ascii=False)
if domain_model.metadata
else "{}",
),
("created_at", domain_model.created_at.isoformat() if domain_model.created_at else ""),
("created_by_role", self._creator_user_role.value),
("created_by", self._creator_user_id),
("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""),
]
return logstore_model
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution domain entity to LogStore.
This method serves as a domain-to-logstore adapter that:
1. Converts the domain entity to its logstore representation
2. Appends a new record with a log_version timestamp
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED)
Each save operation creates a new record. Updates are simulated by writing
new records with higher log_version numbers.
Args:
execution: The NodeExecution domain entity to persist
"""
logger.debug(
"save: id=%s, node_execution_id=%s, status=%s",
execution.id,
execution.node_execution_id,
execution.status.value,
)
try:
logstore_model = self._to_logstore_model(execution)
self.logstore_client.put_log(AliyunLogStore.workflow_node_execution_logstore, logstore_model)
logger.debug(
"Saved node execution to LogStore: id=%s, node_execution_id=%s, status=%s",
execution.id,
execution.node_execution_id,
execution.status.value,
)
except Exception:
logger.exception(
"Failed to save node execution to LogStore: id=%s, node_execution_id=%s",
execution.id,
execution.node_execution_id,
)
raise
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save(execution)
logger.debug("Dual-write: saved node execution to SQL database: id=%s", execution.id)
except Exception:
logger.exception("Failed to dual-write node execution to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
def save_execution_data(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
For LogStore implementation, this is similar to save() since we always write
complete records. We append a new record with updated data fields.
Args:
execution: The NodeExecution instance with data to save
"""
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
# In LogStore, we simply write a new complete record with the data
# The log_version timestamp will ensure this is treated as the latest version
self.save(execution)
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: OrderConfig | None = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
This ensures we only get the final version of each node execution.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of NodeExecution instances
Note:
This method filters by finished_at IS NOT NULL to avoid duplicates from
version updates. For complete history including intermediate states,
a different query strategy would be needed.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
# Build SQL query with deduplication using finished_at IS NOT NULL
# This optimization avoids window functions for common case where we only
# want the final state of each node execution
# Build ORDER BY clause
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
for field in order_config.order_by:
# Map domain field names to logstore field names if needed
field_name = field
if order_config.order_direction == "desc":
order_fields.append(f"{field_name} DESC")
else:
order_fields.append(f"{field_name} ASC")
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
sql = f"""
SELECT *
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{workflow_run_id}'
AND tenant_id='{self._tenant_id}'
AND finished_at IS NOT NULL
"""
if self._app_id:
sql += f" AND app_id='{self._app_id}'"
if order_clause:
sql += f" {order_clause}"
try:
# Execute SQL query
results = self.logstore_client.execute_sql(
sql=sql,
query="*",
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
# Convert LogStore results to WorkflowNodeExecution domain models
executions = []
for row in results:
try:
execution = _dict_to_workflow_node_execution(row)
executions.append(execution)
except Exception as e:
logger.warning("Failed to convert row to WorkflowNodeExecution: %s, row=%s", e, row)
continue
return executions
except Exception:
logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id)
raise

View File

@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage):
if not self.exists(path):
raise FileNotFoundError("Path not found")
all_files = self.op.scan(path=path)
# Use the new OpenDAL 0.46.0+ API with recursive listing
lister = self.op.list(path, recursive=True)
if files and directories:
logger.debug("files and directories on %s scanned", path)
return [f.path for f in all_files]
return [entry.path for entry in lister]
if files:
logger.debug("files on %s scanned", path)
return [f.path for f in all_files if not f.path.endswith("/")]
return [entry.path for entry in lister if not entry.metadata.is_dir]
elif directories:
logger.debug("directories on %s scanned", path)
return [f.path for f in all_files if f.path.endswith("/")]
return [entry.path for entry in lister if entry.metadata.is_dir]
else:
raise ValueError("At least one of files or directories must be True")

66
api/libs/encryption.py Normal file
View File

@ -0,0 +1,66 @@
"""
Field Encoding/Decoding Utilities
Provides Base64 decoding for sensitive fields (password, verification code)
received from the frontend.
Note: This uses Base64 encoding for obfuscation, not cryptographic encryption.
Real security relies on HTTPS for transport layer encryption.
"""
import base64
import logging
logger = logging.getLogger(__name__)
class FieldEncryption:
"""Handle decoding of sensitive fields during transmission"""
@classmethod
def decrypt_field(cls, encoded_text: str) -> str | None:
"""
Decode Base64 encoded field from frontend.
Args:
encoded_text: Base64 encoded text from frontend
Returns:
Decoded plaintext, or None if decoding fails
"""
try:
# Decode base64
decoded_bytes = base64.b64decode(encoded_text)
decoded_text = decoded_bytes.decode("utf-8")
logger.debug("Field decoding successful")
return decoded_text
except Exception:
# Decoding failed - return None to trigger error in caller
return None
@classmethod
def decrypt_password(cls, encrypted_password: str) -> str | None:
"""
Decrypt password field
Args:
encrypted_password: Encrypted password from frontend
Returns:
Decrypted password or None if decryption fails
"""
return cls.decrypt_field(encrypted_password)
@classmethod
def decrypt_verification_code(cls, encrypted_code: str) -> str | None:
"""
Decrypt verification code field
Args:
encrypted_code: Encrypted code from frontend
Returns:
Decrypted code or None if decryption fails
"""
return cls.decrypt_field(encrypted_code)

View File

@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
from uuid import UUID
from zoneinfo import available_timezones
from flask import Response, stream_with_context
@ -119,6 +120,19 @@ def uuid_value(value: Any) -> str:
raise ValueError(error)
def normalize_uuid(value: str | UUID) -> str:
if not value:
return ""
try:
return uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)]
def alphanumeric(value: str):
# check if the value is alphanumeric and underlined
if re.match(r"^[a-zA-Z0-9_]+$", value):

View File

@ -0,0 +1,31 @@
"""add type column not null default tool
Revision ID: 03ea244985ce
Revises: d57accd375ae
Create Date: 2025-12-16 18:17:12.193877
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '03ea244985ce'
down_revision = 'd57accd375ae'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=50), server_default=sa.text("'tool'"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
batch_op.drop_column('type')
# ### end Alembic commands ###

View File

@ -1532,6 +1532,7 @@ class PipelineRecommendedPlugin(TypeBase):
)
plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
type: Mapped[str] = mapped_column(sa.String(50), nullable=False, server_default=sa.text("'tool'"))
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
created_at: Mapped[datetime] = mapped_column(

View File

@ -4,6 +4,7 @@ version = "1.11.1"
requires-python = ">=3.11,<3.13"
dependencies = [
"aliyun-log-python-sdk~=0.9.37",
"arize-phoenix-otel~=0.9.2",
"azure-identity==1.16.1",
"beautifulsoup4==4.12.2",
@ -31,6 +32,7 @@ dependencies = [
"httpx[socks]~=0.27.0",
"jieba==0.42.1",
"json-repair>=0.41.1",
"jsonschema>=4.25.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
"markdown~=3.5.1",
@ -91,7 +93,6 @@ dependencies = [
"weaviate-client==4.17.0",
"apscheduler>=3.11.0",
"weave>=0.52.16",
"jsonschema>=4.25.1",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.

View File

@ -1,8 +1,12 @@
import logging
import os
from collections.abc import Sequence
from typing import Literal
import httpx
from pydantic import TypeAdapter
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError
from enums.cloud_plan import CloudPlan
@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
from models import Account, TenantAccountJoin, TenantAccountRole
logger = logging.getLogger(__name__)
class SubscriptionPlan(TypedDict):
"""Tenant subscriptionplan information."""
plan: str
expiration_date: int
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
@ -239,3 +252,39 @@ class BillingService:
def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
payload = {"account_id": account_id, "click_id": click_id}
return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)
@classmethod
def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
"""
Bulk fetch billing subscription plan via billing API.
Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
Returns:
Mapping of tenant_id -> {plan: str, expiration_date: int}
"""
results: dict[str, SubscriptionPlan] = {}
subscription_adapter = TypeAdapter(SubscriptionPlan)
chunk_size = 200
for i in range(0, len(tenant_ids), chunk_size):
chunk = tenant_ids[i : i + chunk_size]
try:
resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
data = resp.get("data", {})
for tenant_id, plan in data.items():
subscription_plan = subscription_adapter.validate_python(plan)
results[tenant_id] = subscription_plan
except Exception:
logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
continue
return results
@classmethod
def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
data = resp.get("data", [])
tenant_whitelist = []
for item in data:
tenant_whitelist.append(item["tenant_id"])
return tenant_whitelist

View File

@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel):
description: str
icon_info: IconInfo
permission: str
partial_member_list: list[str] | None = None
partial_member_list: list[dict[str, str]] | None = None
yaml_content: str | None = None

View File

@ -1248,14 +1248,13 @@ class RagPipelineService:
session.commit()
return workflow_node_execution_db_model
def get_recommended_plugins(self) -> dict:
def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
pipeline_recommended_plugins = (
db.session.query(PipelineRecommendedPlugin)
.where(PipelineRecommendedPlugin.active == True)
.order_by(PipelineRecommendedPlugin.position.asc())
.all()
)
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type)
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
if not pipeline_recommended_plugins:
return {

View File

@ -410,9 +410,12 @@ class VariableTruncator(BaseTruncator):
@overload
def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...
@overload
def _truncate_json_primitives(self, val: File, target_size: int) -> _PartResult[File]: ...
def _truncate_json_primitives(
self,
val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None,
val: UpdatedVariable | File | str | list[object] | dict[str, object] | bool | int | float | None,
target_size: int,
) -> _PartResult[Any]:
"""Truncate a value within an object to fit within budget."""
@ -425,6 +428,9 @@ class VariableTruncator(BaseTruncator):
return self._truncate_array(val, target_size)
elif isinstance(val, dict):
return self._truncate_object(val, target_size)
elif isinstance(val, File):
# File objects should not be truncated, return as-is
return _PartResult(val, self.calculate_json_size(val), False)
elif val is None or isinstance(val, (bool, int, float)):
return _PartResult(val, self.calculate_json_size(val), False)
else:

View File

@ -113,16 +113,31 @@ class TestShardedRedisBroadcastChannelIntegration:
topic = broadcast_channel.topic(topic_name)
producer = topic.as_producer()
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
ready_events = [threading.Event() for _ in range(subscriber_count)]
def producer_thread():
time.sleep(0.2) # Allow all subscribers to connect
deadline = time.time() + 5.0
for ev in ready_events:
remaining = deadline - time.time()
if remaining <= 0:
break
if not ev.wait(timeout=max(0.0, remaining)):
pytest.fail("subscriber did not become ready before publish deadline")
producer.publish(message)
time.sleep(0.2)
for sub in subscriptions:
sub.close()
def consumer_thread(subscription: Subscription) -> list[bytes]:
def consumer_thread(subscription: Subscription, ready_event: threading.Event) -> list[bytes]:
received_msgs = []
# Prime subscription so the underlying Pub/Sub listener thread starts before publishing
try:
_ = subscription.receive(0.01)
except SubscriptionClosedError:
return received_msgs
finally:
ready_event.set()
while True:
try:
msg = subscription.receive(0.1)
@ -137,7 +152,10 @@ class TestShardedRedisBroadcastChannelIntegration:
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
producer_future = executor.submit(producer_thread)
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
consumer_futures = [
executor.submit(consumer_thread, subscription, ready_events[idx])
for idx, subscription in enumerate(subscriptions)
]
producer_future.result(timeout=10.0)
msgs_by_consumers = []

View File

@ -1,5 +1,6 @@
"""Test authentication security to prevent user enumeration."""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -11,6 +12,11 @@ from controllers.console.auth.error import AuthenticationFailedError
from controllers.console.auth.login import LoginApi
def encode_password(password: str) -> str:
"""Helper to encode password as Base64 for testing."""
return base64.b64encode(password.encode("utf-8")).decode()
class TestAuthenticationSecurity:
"""Test authentication endpoints for security against user enumeration."""
@ -42,7 +48,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
@ -72,7 +80,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "existing@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
@ -104,7 +114,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()

View File

@ -8,6 +8,7 @@ This module tests the email code login mechanism including:
- Workspace creation for new users
"""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -25,6 +26,11 @@ from controllers.console.error import (
from services.errors.account import AccountRegisterError
def encode_code(code: str) -> str:
"""Helper to encode verification code as Base64 for testing."""
return base64.b64encode(code.encode("utf-8")).decode()
class TestEmailCodeLoginSendEmailApi:
"""Test cases for sending email verification codes."""
@ -290,7 +296,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "valid_token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "valid_token"},
):
api = EmailCodeLoginApi()
response = api.post()
@ -339,7 +345,12 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "newuser@example.com", "code": "123456", "token": "valid_token", "language": "en-US"},
json={
"email": "newuser@example.com",
"code": encode_code("123456"),
"token": "valid_token",
"language": "en-US",
},
):
api = EmailCodeLoginApi()
response = api.post()
@ -365,7 +376,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "invalid_token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "invalid_token"},
):
api = EmailCodeLoginApi()
with pytest.raises(InvalidTokenError):
@ -388,7 +399,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "different@example.com", "code": "123456", "token": "token"},
json={"email": "different@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(InvalidEmailError):
@ -411,7 +422,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "wrong_code", "token": "token"},
json={"email": "test@example.com", "code": encode_code("wrong_code"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(EmailCodeError):
@ -497,7 +508,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(WorkspacesLimitExceeded):
@ -539,7 +550,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(NotAllowedCreateWorkspace):

View File

@ -8,6 +8,7 @@ This module tests the core authentication endpoints including:
- Account status validation
"""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -28,6 +29,11 @@ from controllers.console.error import (
from services.errors.account import AccountLoginError, AccountPasswordError
def encode_password(password: str) -> str:
"""Helper to encode password as Base64 for testing."""
return base64.b64encode(password.encode("utf-8")).decode()
class TestLoginApi:
"""Test cases for the LoginApi endpoint."""
@ -106,7 +112,9 @@ class TestLoginApi:
# Act
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
"/login",
method="POST",
json={"email": "test@example.com", "password": encode_password("ValidPass123!")},
):
login_api = LoginApi()
response = login_api.post()
@ -158,7 +166,11 @@ class TestLoginApi:
with app.test_request_context(
"/login",
method="POST",
json={"email": "test@example.com", "password": "ValidPass123!", "invite_token": "valid_token"},
json={
"email": "test@example.com",
"password": encode_password("ValidPass123!"),
"invite_token": "valid_token",
},
):
login_api = LoginApi()
response = login_api.post()
@ -186,7 +198,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "password"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(EmailPasswordLoginLimitError):
@ -209,7 +221,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "frozen@example.com", "password": "password"}
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(AccountInFreezeError):
@ -246,7 +258,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "WrongPass123!"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")}
):
login_api = LoginApi()
with pytest.raises(AuthenticationFailedError):
@ -277,7 +289,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "banned@example.com", "password": "ValidPass123!"}
"/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}
):
login_api = LoginApi()
with pytest.raises(AccountBannedError):
@ -322,7 +334,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("ValidPass123!")}
):
login_api = LoginApi()
with pytest.raises(WorkspacesLimitExceeded):
@ -349,7 +361,11 @@ class TestLoginApi:
with app.test_request_context(
"/login",
method="POST",
json={"email": "different@example.com", "password": "ValidPass123!", "invite_token": "token"},
json={
"email": "different@example.com",
"password": encode_password("ValidPass123!"),
"invite_token": "token",
},
):
login_api = LoginApi()
with pytest.raises(InvalidEmailError):

View File

@ -0,0 +1,420 @@
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
import pytest
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueErrorEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
QueuePingEvent,
)
from core.app.entities.task_entities import (
EasyUITaskState,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
PingStreamResponse,
StreamEvent,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
"""Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method."""
@pytest.fixture
def mock_application_generate_entity(self):
"""Create a mock application generate entity."""
entity = Mock(spec=ChatAppGenerateEntity)
entity.task_id = "test-task-id"
entity.app_id = "test-app-id"
# minimal app_config used by pipeline internals
entity.app_config = SimpleNamespace(
tenant_id="test-tenant-id",
app_id="test-app-id",
app_mode=AppMode.CHAT,
app_model_config_dict={},
additional_features=None,
sensitive_word_avoidance=None,
)
# minimal model_conf for LLMResult init
entity.model_conf = SimpleNamespace(
model="test-model",
provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
credentials={},
)
return entity
@pytest.fixture
def mock_queue_manager(self):
"""Create a mock queue manager."""
manager = Mock(spec=AppQueueManager)
return manager
@pytest.fixture
def mock_message_cycle_manager(self):
"""Create a mock message cycle manager."""
manager = Mock()
manager.get_message_event_type.return_value = StreamEvent.MESSAGE
manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
manager.handle_retriever_resources = Mock()
manager.handle_annotation_reply.return_value = None
return manager
@pytest.fixture
def mock_conversation(self):
"""Create a mock conversation."""
conversation = Mock()
conversation.id = "test-conversation-id"
conversation.mode = "chat"
return conversation
@pytest.fixture
def mock_message(self):
"""Create a mock message."""
message = Mock()
message.id = "test-message-id"
message.created_at = Mock()
message.created_at.timestamp.return_value = 1234567890
return message
@pytest.fixture
def mock_task_state(self):
"""Create a mock task state."""
task_state = Mock(spec=EasyUITaskState)
# Create LLM result mock
llm_result = Mock(spec=RuntimeLLMResult)
llm_result.prompt_messages = []
llm_result.message = Mock()
llm_result.message.content = ""
task_state.llm_result = llm_result
task_state.answer = ""
return task_state
@pytest.fixture
def pipeline(
self,
mock_application_generate_entity,
mock_queue_manager,
mock_conversation,
mock_message,
mock_message_cycle_manager,
mock_task_state,
):
"""Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies."""
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
):
pipeline = EasyUIBasedGenerateTaskPipeline(
application_generate_entity=mock_application_generate_entity,
queue_manager=mock_queue_manager,
conversation=mock_conversation,
message=mock_message,
stream=True,
)
pipeline._message_cycle_manager = mock_message_cycle_manager
pipeline._task_state = mock_task_state
return pipeline
def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
self, pipeline, mock_message_cycle_manager
):
"""Expect get_message_event_type to be called when processing the first LLM chunk event."""
# Setup a minimal LLM chunk event
chunk = Mock()
chunk.delta.message.content = "hi"
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
# Execute
list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")
def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of LLM chunk events with text content."""
# Setup
chunk = Mock()
chunk.delta.message.content = "Hello, world!"
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
assert mock_task_state.llm_result.message.content == "Hello, world!"
def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of LLM chunk events with list content."""
# Setup
text_content = Mock(spec=TextPromptMessageContent)
text_content.data = "Hello"
chunk = Mock()
chunk.delta.message.content = [text_content, " world!"]
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
assert mock_task_state.llm_result.message.content == "Hello world!"
def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of agent message events."""
# Setup
chunk = Mock()
chunk.delta.message.content = "Agent response"
agent_message_event = Mock(spec=QueueAgentMessageEvent)
agent_message_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = agent_message_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
# Ensure method under assertion is a mock to track calls
pipeline._agent_message_to_stream_response = Mock(return_value=Mock())
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
# Agent messages should use _agent_message_to_stream_response
pipeline._agent_message_to_stream_response.assert_called_once_with(
answer="Agent response", message_id="test-message-id"
)
def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of message end events."""
# Setup
llm_result = Mock(spec=RuntimeLLMResult)
llm_result.message = Mock()
llm_result.message.content = "Final response"
message_end_event = Mock(spec=QueueMessageEndEvent)
message_end_event.llm_result = llm_result
mock_queue_message = Mock()
mock_queue_message.event = message_end_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline._save_message = Mock()
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
assert mock_task_state.llm_result == llm_result
pipeline._save_message.assert_called_once()
pipeline._message_end_to_stream_response.assert_called_once()
def test_error_event(self, pipeline):
"""Test handling of error events."""
# Setup
error_event = Mock(spec=QueueErrorEvent)
error_event.error = Exception("Test error")
mock_queue_message = Mock()
mock_queue_message.event = error_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.handle_error = Mock(return_value=Exception("Test error"))
pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
pipeline.handle_error.assert_called_once()
pipeline.error_to_stream_response.assert_called_once()
def test_ping_event(self, pipeline):
"""Test handling of ping events."""
# Setup
ping_event = Mock(spec=QueuePingEvent)
mock_queue_message = Mock()
mock_queue_message.event = ping_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
pipeline.ping_stream_response.assert_called_once()
def test_file_event(self, pipeline, mock_message_cycle_manager):
"""Test handling of file events."""
# Setup
file_event = Mock(spec=QueueMessageFileEvent)
file_event.message_file_id = "file-id"
mock_queue_message = Mock()
mock_queue_message.event = file_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
file_response = Mock(spec=MessageFileStreamResponse)
mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
assert responses[0] == file_response
mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event)
def test_publisher_is_called_with_messages(self, pipeline):
"""Test that publisher publishes messages when provided."""
# Setup
publisher = Mock(spec=AppGeneratorTTSPublisher)
ping_event = Mock(spec=QueuePingEvent)
mock_queue_message = Mock()
mock_queue_message.event = ping_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
list(pipeline._process_stream_response(publisher=publisher, trace_manager=None))
# Assert
# Called once with message and once with None at the end
assert publisher.publish.call_count == 2
publisher.publish.assert_any_call(mock_queue_message)
publisher.publish.assert_any_call(None)
def test_trace_manager_passed_to_save_message(self, pipeline):
"""Test that trace manager is passed to _save_message."""
# Setup
trace_manager = Mock(spec=TraceQueueManager)
message_end_event = Mock(spec=QueueMessageEndEvent)
message_end_event.llm_result = None
mock_queue_message = Mock()
mock_queue_message.event = message_end_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline._save_message = Mock()
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager))
# Assert
pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager)
def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling multiple events in sequence."""
# Setup
chunk1 = Mock()
chunk1.delta.message.content = "Hello"
chunk1.prompt_messages = []
chunk2 = Mock()
chunk2.delta.message.content = " world!"
chunk2.prompt_messages = []
llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event1.chunk = chunk1
ping_event = Mock(spec=QueuePingEvent)
llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event2.chunk = chunk2
mock_queue_messages = [
Mock(event=llm_chunk_event1),
Mock(event=ping_event),
Mock(event=llm_chunk_event2),
]
pipeline.queue_manager.listen.return_value = mock_queue_messages
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 3
assert mock_task_state.llm_result.message.content == "Hello world!"
# Verify calls to message_to_stream_response
assert mock_message_cycle_manager.message_to_stream_response.call_count == 2
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)

View File

@ -0,0 +1,166 @@
"""Unit tests for the message cycle manager optimization."""
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
import pytest
from flask import current_app
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
class TestMessageCycleManagerOptimization:
"""Test cases for the message cycle manager optimization that prevents N+1 queries."""
@pytest.fixture
def mock_application_generate_entity(self):
"""Create a mock application generate entity."""
entity = Mock()
entity.task_id = "test-task-id"
return entity
@pytest.fixture
def message_cycle_manager(self, mock_application_generate_entity):
"""Create a message cycle manager instance."""
task_state = Mock()
return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Execute
with current_app.app_context():
result = message_cycle_manager.get_message_event_type("test-message-id")
# Assert
assert result == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message has no files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and no message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None
# Execute
with current_app.app_context():
result = message_cycle_manager.get_message_event_type("test-message-id")
# Assert
assert result == StreamEvent.MESSAGE
mock_session.query.return_value.scalar.assert_called_once()
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Execute: compute event type once, then pass to message_to_stream_response
with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id")
result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=event_type
)
# Assert
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
"""Test that message_to_stream_response skips database query when event_type is provided."""
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
# Execute with event_type provided
result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
# Assert
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE
# Should not query database when event_type is provided
mock_session_class.assert_not_called()
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
"""Test message_to_stream_response with from_variable_selector parameter."""
result = message_cycle_manager.message_to_stream_response(
answer="Hello world",
message_id="test-message-id",
from_variable_selector=["var1", "var2"],
event_type=StreamEvent.MESSAGE,
)
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.from_variable_selector == ["var1", "var2"]
assert result.event == StreamEvent.MESSAGE
def test_optimization_usage_example(self, message_cycle_manager):
"""Test the optimization pattern that should be used by callers."""
# Step 1: Get event type once (this queries database)
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None # No files
with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id")
# Should query database once
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
assert event_type == StreamEvent.MESSAGE
# Step 2: Use event_type for multiple calls (no additional queries)
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
mock_session_class.return_value.__enter__.return_value = Mock()
chunk1_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 1", message_id="test-message-id", event_type=event_type
)
chunk2_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 2", message_id="test-message-id", event_type=event_type
)
# Should not query database again
mock_session_class.assert_not_called()
assert chunk1_response.event == StreamEvent.MESSAGE
assert chunk2_response.event == StreamEvent.MESSAGE
assert chunk1_response.answer == "Chunk 1"
assert chunk2_response.answer == "Chunk 2"

View File

@ -0,0 +1,10 @@
import tempfile
from core.rag.extractor.helpers import FileEncoding, detect_file_encodings
def test_detect_file_encodings() -> None:
with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp:
temp.write("Shared data")
temp_path = temp.name
assert detect_file_encodings(temp_path) == [FileEncoding(encoding="utf_8", confidence=0.0, language="Unknown")]

View File

@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter:
# Verify no empty chunks
assert all(len(chunk) > 0 for chunk in result)
def test_double_slash_n(self):
data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
separator = "\\n\\n---\\n\\n"
splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator)
chunks = splitter.split_text(data)
assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]
# ============================================================================
# Test Metadata Preservation

View File

@ -1,3 +1,4 @@
import json
import time
import pytest
@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables):
def test_json_object_valid_schema():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
)
variables = [
VariableEntity(
@ -65,7 +68,7 @@ def test_json_object_valid_schema():
)
]
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
node = make_start_node(user_inputs, variables)
result = node._run()
@ -74,12 +77,23 @@ def test_json_object_valid_schema():
def test_json_object_invalid_json_string():
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
json_schema=schema,
)
]
@ -88,38 +102,21 @@ def test_json_object_invalid_json_string():
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
def test_json_object_valid_json_but_not_object(value):
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
)
]
user_inputs = {"profile": value}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
node._run()
def test_json_object_does_not_match_schema():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema():
]
# age is a string, which violates the schema (expects number)
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
node = make_start_node(user_inputs, variables)
@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema():
def test_json_object_missing_required_schema_field():
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field():
]
# Missing required field "name"
user_inputs = {"profile": {"age": 20}}
user_inputs = {"profile": json.dumps({"age": 20})}
node = make_start_node(user_inputs, variables)
@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided():
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=False,
required=True,
)
]
@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided():
node = make_start_node(user_inputs, variables)
# Current implementation raises a validation error even when the variable is optional
with pytest.raises(ValueError, match="profile must be a JSON object"):
with pytest.raises(ValueError, match="profile is required in input form"):
node._run()

View File

@ -0,0 +1,150 @@
"""
Unit tests for field encoding/decoding utilities.
These tests verify Base64 encoding/decoding functionality and
proper error handling and fallback behavior.
"""
import base64
from libs.encryption import FieldEncryption
class TestDecodeField:
"""Test cases for field decoding functionality."""
def test_decode_valid_base64(self):
"""Test decoding a valid Base64 encoded string."""
plaintext = "password123"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
def test_decode_non_base64_returns_none(self):
"""Test that non-base64 input returns None."""
non_base64 = "plain-password-!@#"
result = FieldEncryption.decrypt_field(non_base64)
# Should return None (decoding failed)
assert result is None
def test_decode_unicode_text(self):
"""Test decoding Base64 encoded Unicode text."""
plaintext = "密码Test123"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
def test_decode_empty_string(self):
"""Test decoding an empty string returns empty string."""
result = FieldEncryption.decrypt_field("")
# Empty string base64 decodes to empty string
assert result == ""
def test_decode_special_characters(self):
"""Test decoding with special characters."""
plaintext = "P@ssw0rd!#$%^&*()"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
class TestDecodePassword:
"""Test cases for password decoding."""
def test_decode_password_base64(self):
"""Test decoding a Base64 encoded password."""
password = "SecureP@ssw0rd!"
encoded = base64.b64encode(password.encode("utf-8")).decode()
result = FieldEncryption.decrypt_password(encoded)
assert result == password
def test_decode_password_invalid_returns_none(self):
"""Test that invalid base64 passwords return None."""
invalid = "PlainPassword!@#"
result = FieldEncryption.decrypt_password(invalid)
# Should return None (decoding failed)
assert result is None
class TestDecodeVerificationCode:
"""Test cases for verification code decoding."""
def test_decode_code_base64(self):
"""Test decoding a Base64 encoded verification code."""
code = "789012"
encoded = base64.b64encode(code.encode("utf-8")).decode()
result = FieldEncryption.decrypt_verification_code(encoded)
assert result == code
def test_decode_code_invalid_returns_none(self):
"""Test that invalid base64 codes return None."""
invalid = "123456" # Plain 6-digit code, not base64
result = FieldEncryption.decrypt_verification_code(invalid)
# Should return None (decoding failed)
assert result is None
class TestRoundTripEncodingDecoding:
"""
Integration tests for complete encoding-decoding cycle.
These tests simulate the full frontend-to-backend flow using Base64.
"""
def test_roundtrip_password(self):
"""Test encoding and decoding a password."""
original_password = "SecureP@ssw0rd!"
# Simulate frontend encoding (Base64)
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_verification_code(self):
"""Test encoding and decoding a verification code."""
original_code = "123456"
# Simulate frontend encoding
encoded = base64.b64encode(original_code.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_verification_code(encoded)
assert decoded == original_code
def test_roundtrip_unicode_password(self):
"""Test encoding and decoding password with Unicode characters."""
original_password = "密码Test123!@#"
# Frontend encoding
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_long_password(self):
"""Test encoding and decoding a long password."""
original_password = "ThisIsAVeryLongPasswordWithLotsOfCharacters123!@#$%^&*()"
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_with_whitespace(self):
"""Test encoding and decoding with whitespace."""
original_password = "pass word with spaces"
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
decoded = FieldEncryption.decrypt_field(encoded)
assert decoded == original_password

View File

@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases:
assert "Only team owner or team admin can perform this action" in str(exc_info.value)
class TestBillingServiceSubscriptionOperations:
"""Unit tests for subscription operations in BillingService.
Tests cover:
- Bulk plan retrieval with chunking
- Expired subscription cleanup whitelist retrieval
"""
@pytest.fixture
def mock_send_request(self):
"""Mock _send_request method."""
with patch.object(BillingService, "_send_request") as mock:
yield mock
def test_get_plan_bulk_with_empty_list(self, mock_send_request):
"""Test bulk plan retrieval with empty tenant list."""
# Arrange
tenant_ids = []
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert result == {}
mock_send_request.assert_not_called()
def test_get_plan_bulk_with_chunking(self, mock_send_request):
"""Test bulk plan retrieval with more than 200 tenants (chunking logic)."""
# Arrange - 250 tenants to test chunking (chunk_size = 200)
tenant_ids = [f"tenant-{i}" for i in range(250)]
# First chunk: tenants 0-199
first_chunk_response = {
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
}
# Second chunk: tenants 200-249
second_chunk_response = {
"data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)}
}
mock_send_request.side_effect = [first_chunk_response, second_chunk_response]
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert len(result) == 250
assert result["tenant-0"]["plan"] == "sandbox"
assert result["tenant-199"]["plan"] == "sandbox"
assert result["tenant-200"]["plan"] == "professional"
assert result["tenant-249"]["plan"] == "professional"
assert mock_send_request.call_count == 2
# Verify first chunk call
first_call = mock_send_request.call_args_list[0]
assert first_call[0][0] == "POST"
assert first_call[0][1] == "/subscription/plan/batch"
assert len(first_call[1]["json"]["tenant_ids"]) == 200
# Verify second chunk call
second_call = mock_send_request.call_args_list[1]
assert len(second_call[1]["json"]["tenant_ids"]) == 50
def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request):
"""Test bulk plan retrieval when one batch fails but others succeed."""
# Arrange - 250 tenants, second batch will fail
tenant_ids = [f"tenant-{i}" for i in range(250)]
# First chunk succeeds
first_chunk_response = {
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
}
# Second chunk fails - need to create a mock that raises when called
def side_effect_func(*args, **kwargs):
if mock_send_request.call_count == 1:
return first_chunk_response
else:
raise ValueError("API error")
mock_send_request.side_effect = side_effect_func
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert - should only have data from first batch
assert len(result) == 200
assert result["tenant-0"]["plan"] == "sandbox"
assert result["tenant-199"]["plan"] == "sandbox"
assert "tenant-200" not in result
assert mock_send_request.call_count == 2
def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request):
"""Test bulk plan retrieval when all batches fail."""
# Arrange
tenant_ids = [f"tenant-{i}" for i in range(250)]
# All chunks fail
def side_effect_func(*args, **kwargs):
raise ValueError("API error")
mock_send_request.side_effect = side_effect_func
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert - should return empty dict
assert result == {}
assert mock_send_request.call_count == 2
def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request):
"""Test bulk plan retrieval with exactly 200 tenants (boundary condition)."""
# Arrange
tenant_ids = [f"tenant-{i}" for i in range(200)]
mock_send_request.return_value = {
"data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)}
}
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert len(result) == 200
assert mock_send_request.call_count == 1
def test_get_plan_bulk_with_empty_data_response(self, mock_send_request):
"""Test bulk plan retrieval with empty data in response."""
# Arrange
tenant_ids = ["tenant-1", "tenant-2"]
mock_send_request.return_value = {"data": {}}
# Act
result = BillingService.get_plan_bulk(tenant_ids)
# Assert
assert result == {}
def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request):
"""Test successful retrieval of expired subscription cleanup whitelist."""
# Arrange
api_response = [
{
"created_at": "2025-10-16T01:56:17",
"tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6",
"contact": "example@dify.ai",
"id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5",
"expired_at": "2026-01-01T01:56:17",
"updated_at": "2025-10-16T01:56:17",
},
{
"created_at": "2025-10-16T02:00:00",
"tenant_id": "tenant-2",
"contact": "test@example.com",
"id": "whitelist-id-2",
"expired_at": "2026-02-01T00:00:00",
"updated_at": "2025-10-16T02:00:00",
},
{
"created_at": "2025-10-16T03:00:00",
"tenant_id": "tenant-3",
"contact": "another@example.com",
"id": "whitelist-id-3",
"expired_at": "2026-03-01T00:00:00",
"updated_at": "2025-10-16T03:00:00",
},
]
mock_send_request.return_value = {"data": api_response}
# Act
result = BillingService.get_expired_subscription_cleanup_whitelist()
# Assert - should return only tenant_ids
assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"]
assert len(result) == 3
assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6"
assert result[1] == "tenant-2"
assert result[2] == "tenant-3"
mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist")
def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request):
"""Test retrieval of empty cleanup whitelist."""
# Arrange
mock_send_request.return_value = {"data": []}
# Act
result = BillingService.get_expired_subscription_cleanup_whitelist()
# Assert
assert result == []
assert len(result) == 0
class TestBillingServiceIntegrationScenarios:
"""Integration-style tests simulating real-world usage scenarios.

View File

@ -518,6 +518,55 @@ class TestEdgeCases:
assert isinstance(result.result, StringSegment)
class TestTruncateJsonPrimitives:
"""Test _truncate_json_primitives method with different data types."""
@pytest.fixture
def truncator(self):
return VariableTruncator()
def test_truncate_json_primitives_file_type(self, truncator, file):
"""Test that File objects are handled correctly in _truncate_json_primitives."""
# Test File object is returned as-is without truncation
result = truncator._truncate_json_primitives(file, 1000)
assert result.value == file
assert result.truncated is False
# Size should be calculated correctly
expected_size = VariableTruncator.calculate_json_size(file)
assert result.value_size == expected_size
def test_truncate_json_primitives_file_type_small_budget(self, truncator, file):
"""Test that File objects are returned as-is even with small budget."""
# Even with a small size budget, File objects should not be truncated
result = truncator._truncate_json_primitives(file, 10)
assert result.value == file
assert result.truncated is False
def test_truncate_json_primitives_file_type_in_array(self, truncator, file):
"""Test File objects in arrays are handled correctly."""
array_with_files = [file, file]
result = truncator._truncate_json_primitives(array_with_files, 1000)
assert isinstance(result.value, list)
assert len(result.value) == 2
assert result.value[0] == file
assert result.value[1] == file
assert result.truncated is False
def test_truncate_json_primitives_file_type_in_object(self, truncator, file):
"""Test File objects in objects are handled correctly."""
obj_with_files = {"file1": file, "file2": file}
result = truncator._truncate_json_primitives(obj_with_files, 1000)
assert isinstance(result.value, dict)
assert len(result.value) == 2
assert result.value["file1"] == file
assert result.value["file2"] == file
assert result.truncated is False
class TestIntegrationScenarios:
"""Test realistic integration scenarios."""

File diff suppressed because it is too large Load Diff

View File

@ -37,6 +37,7 @@ show_help() {
echo " pipeline - Standard pipeline tasks"
echo " triggered_workflow_dispatcher - Trigger dispatcher tasks"
echo " trigger_refresh_executor - Trigger refresh tasks"
echo " retention - Retention tasks"
}
# Parse command line arguments
@ -105,10 +106,10 @@ if [[ -z "${QUEUES}" ]]; then
# Configure queues based on edition
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
else
# Community edition (SELF_HOSTED): dataset and workflow have separate queues
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
fi
echo "No queues specified, using edition-based defaults: ${QUEUES}"

View File

@ -1044,6 +1044,25 @@ WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# Aliyun SLS Logstore Configuration
# Aliyun Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365 days 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
@ -1350,7 +1369,10 @@ PLUGIN_STDIO_BUFFER_SIZE=1024
PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880
PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120
# Plugin Daemon side timeout (configure to match the API side below)
PLUGIN_MAX_EXECUTION_TIMEOUT=600
# API side timeout (configure to match the Plugin Daemon side above)
PLUGIN_DAEMON_TIMEOUT=600.0
# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple
PIP_MIRROR_URL=
@ -1461,3 +1483,8 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
# The API key of amplitude
AMPLITUDE_API_KEY=
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@ -34,6 +34,7 @@ services:
PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
init_permissions:

View File

@ -455,6 +455,14 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_LOG_CLEANUP_ENABLED: ${WORKFLOW_LOG_CLEANUP_ENABLED:-false}
WORKFLOW_LOG_RETENTION_DAYS: ${WORKFLOW_LOG_RETENTION_DAYS:-30}
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: ${WORKFLOW_LOG_CLEANUP_BATCH_SIZE:-100}
ALIYUN_SLS_ACCESS_KEY_ID: ${ALIYUN_SLS_ACCESS_KEY_ID:-}
ALIYUN_SLS_ACCESS_KEY_SECRET: ${ALIYUN_SLS_ACCESS_KEY_SECRET:-}
ALIYUN_SLS_ENDPOINT: ${ALIYUN_SLS_ENDPOINT:-}
ALIYUN_SLS_REGION: ${ALIYUN_SLS_REGION:-}
ALIYUN_SLS_PROJECT_NAME: ${ALIYUN_SLS_PROJECT_NAME:-}
ALIYUN_SLS_LOGSTORE_TTL: ${ALIYUN_SLS_LOGSTORE_TTL:-365}
LOGSTORE_DUAL_WRITE_ENABLED: ${LOGSTORE_DUAL_WRITE_ENABLED:-false}
LOGSTORE_DUAL_READ_ENABLED: ${LOGSTORE_DUAL_READ_ENABLED:-true}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@ -583,6 +591,7 @@ x-shared-env: &shared-api-worker-env
PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880}
PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local}
PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage}
@ -655,6 +664,9 @@ x-shared-env: &shared-api-worker-env
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20}
ANNOTATION_IMPORT_MAX_CONCURRENT: ${ANNOTATION_IMPORT_MAX_CONCURRENT:-5}
AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21}
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000}
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30}
services:
# Init container to fix permissions
@ -691,6 +703,7 @@ services:
PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
init_permissions:

View File

@ -213,3 +213,24 @@ PLUGIN_VOLCENGINE_TOS_ENDPOINT=
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
PLUGIN_VOLCENGINE_TOS_REGION=
# ------------------------------
# Environment Variables for Aliyun SLS (Simple Log Service)
# ------------------------------
# Aliyun SLS Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun SLS Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Aliyun SLS Logstore TTL (default: 365 days 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both LogStore and SQL database (default: true)
LOGSTORE_DUAL_WRITE_ENABLED=true
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true

View File

@ -61,14 +61,14 @@
<p align="center">
<a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>
Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:
Dify est une plateforme de développement d'applications LLM open source. Sa interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:
</br> </br>
**1. Flux de travail** :
Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore.
**2. Prise en charge complète des modèles** :
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
Intégration transparente avec des centaines de LLM propriétaires / open source offerts par dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)
@ -79,7 +79,7 @@ Interface intuitive pour créer des prompts, comparer les performances des modè
Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants.
**5. Capacités d'agent** :
Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha.
Vous pouvez définir des agents basés sur l'appel de fonctions LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha.
**6. LLMOps** :
Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations.

View File

@ -16,7 +16,7 @@ import {
import { useTranslation } from 'react-i18next'
import { useShallow } from 'zustand/react/shallow'
import s from './style.module.css'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useStore } from '@/app/components/app/store'
import AppSideBar from '@/app/components/app-sidebar'
import type { NavIcon } from '@/app/components/app-sidebar/navLink'

View File

@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react'
import type { Dayjs } from 'dayjs'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { formatToLocalTime } from '@/utils/format'
import { useI18N } from '@/context/i18n'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'

View File

@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select'
import type { Item } from '@/app/components/base/select'
import dayjs from 'dayjs'
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useTranslation } from 'react-i18next'
const today = dayjs()

View File

@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react'
import type { PopupProps } from './config-popup'
import ConfigPopup from './config-popup'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import {
PortalToFollowElem,
PortalToFollowElemContent,

View File

@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
import Tooltip from '@/app/components/base/tooltip'
import Divider from '@/app/components/base/divider'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
const I18N_PREFIX = 'app.tracing'

View File

@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Input from '@/app/components/base/input'
type Props = {

View File

@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS
import { TracingProvider } from './type'
import TracingIcon from './tracing-icon'
import ConfigButton from './config-button'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
import Indicator from '@/app/components/header/indicator'
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'

View File

@ -6,7 +6,7 @@ import {
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { TracingProvider } from './type'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'

View File

@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing'
type Props = {

View File

@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use
import useDocumentTitle from '@/hooks/use-document-title'
import ExtraInfo from '@/app/components/datasets/extra-info'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
export type IAppDetailLayoutProps = {
children: React.ReactNode

View File

@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
export default function SignInLayout({ children }: any) {

View File

@ -2,7 +2,7 @@
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useRouter, useSearchParams } from 'next/navigation'
import cn from 'classnames'
import { cn } from '@/utils/classnames'
import { RiCheckboxCircleFill } from '@remixicon/react'
import { useCountDown } from 'ahooks'
import Button from '@/app/components/base/button'

View File

@ -1,6 +1,6 @@
'use client'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import type { PropsWithChildren } from 'react'

View File

@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading'
import MailAndCodeAuth from './components/mail-and-code-auth'
import MailAndPasswordAuth from './components/mail-and-password-auth'
import SSOAuth from './components/sso-auth'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { LicenseStatus } from '@/types/feature'
import { IS_CE_EDITION } from '@/config'
import { useGlobalPublicStore } from '@/context/global-public-context'

View File

@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import { AppContextProvider } from '@/context/app-context'

View File

@ -2,7 +2,7 @@
import { useTranslation } from 'react-i18next'
import useSWR from 'swr'
import { useRouter, useSearchParams } from 'next/navigation'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Button from '@/app/components/base/button'
import { invitationCheck } from '@/service/common'

View File

@ -2,7 +2,7 @@
import React from 'react'
import Header from '../signin/_header'
import ActivateForm from './activateForm'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
const Activate = () => {

View File

@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie
import type { Operation } from './app-operations'
import AppOperations from './app-operations'
import dynamic from 'next/dynamic'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AppModeEnum } from '@/types/app'
const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {

View File

@ -16,7 +16,7 @@ import AppInfo from './app-info'
import NavLink from './navLink'
import { useStore as useAppStore } from '@/app/components/app/store'
import type { NavIcon } from './navLink'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { AppModeEnum } from '@/types/app'
type Props = {

View File

@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem'
import ActionButton from '../../base/action-button'
import { RiMoreFill } from '@remixicon/react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Menu from './menu'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'

View File

@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import type { DataSet } from '@/models/datasets'
import { DOC_FORM_TEXT } from '@/models/datasets'
import { useKnowledge } from '@/hooks/use-knowledge'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Dropdown from './dropdown'
type DatasetInfoProps = {

View File

@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon'
import Divider from '../base/divider'
import NavLink from './navLink'
import type { NavIcon } from './navLink'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import Effect from '../base/effect'
import Dropdown from './dataset-info/dropdown'

View File

@ -9,7 +9,7 @@ import AppSidebarDropdown from './app-sidebar-dropdown'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useStore as useAppStore } from '@/app/components/app/store'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Divider from '../base/divider'
import { useHover, useKeyPress } from 'ahooks'
import ToggleButton from './toggle-button'

View File

@ -2,7 +2,7 @@
import React from 'react'
import { useSelectedLayoutSegment } from 'next/navigation'
import Link from 'next/link'
import classNames from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import type { RemixiconComponentType } from '@remixicon/react'
export type NavIcon = React.ComponentType<
@ -42,7 +42,7 @@ const NavLink = ({
const NavIcon = isActive ? iconMap.selected : iconMap.normal
const renderIcon = () => (
<div className={classNames(mode !== 'expand' && '-ml-1')}>
<div className={cn(mode !== 'expand' && '-ml-1')}>
<NavIcon className="h-4 w-4 shrink-0" aria-hidden="true" />
</div>
)
@ -53,21 +53,17 @@ const NavLink = ({
key={name}
type='button'
disabled
className={classNames(
'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
'pl-3 pr-1',
)}
className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
'pl-3 pr-1')}
title={mode === 'collapse' ? name : ''}
aria-disabled
>
{renderIcon()}
<span
className={classNames(
'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
mode === 'expand'
? 'ml-2 max-w-none opacity-100'
: 'ml-0 max-w-0 opacity-0',
)}
: 'ml-0 max-w-0 opacity-0')}
>
{name}
</span>
@ -79,22 +75,18 @@ const NavLink = ({
<Link
key={name}
href={href}
className={classNames(
isActive
? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
: 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
'flex h-8 items-center rounded-lg pl-3 pr-1',
)}
className={cn(isActive
? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
: 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
'flex h-8 items-center rounded-lg pl-3 pr-1')}
title={mode === 'collapse' ? name : ''}
>
{renderIcon()}
<span
className={classNames(
'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
mode === 'expand'
? 'ml-2 max-w-none opacity-100'
: 'ml-0 max-w-0 opacity-0',
)}
: 'ml-0 max-w-0 opacity-0')}
>
{name}
</span>

View File

@ -1,7 +1,7 @@
import React from 'react'
import Button from '../base/button'
import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import { cn } from '@/utils/classnames'
import Tooltip from '../base/tooltip'
import { useTranslation } from 'react-i18next'
import { getKeyboardKeyNameBySystem } from '../workflow/utils'

View File

@ -0,0 +1,47 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import EditItem, { EditItemType } from './index'
describe('AddAnnotationModal/EditItem', () => {
test('should render query inputs with user avatar and placeholder strings', () => {
render(
<EditItem
type={EditItemType.Query}
content="Why?"
onChange={jest.fn()}
/>,
)
expect(screen.getByText('appAnnotation.addModal.queryName')).toBeInTheDocument()
expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toBeInTheDocument()
expect(screen.getByText('Why?')).toBeInTheDocument()
})
test('should render answer name and placeholder text', () => {
render(
<EditItem
type={EditItemType.Answer}
content="Existing answer"
onChange={jest.fn()}
/>,
)
expect(screen.getByText('appAnnotation.addModal.answerName')).toBeInTheDocument()
expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toBeInTheDocument()
expect(screen.getByDisplayValue('Existing answer')).toBeInTheDocument()
})
test('should propagate changes when answer content updates', () => {
const handleChange = jest.fn()
render(
<EditItem
type={EditItemType.Answer}
content=""
onChange={handleChange}
/>,
)
fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), { target: { value: 'Because' } })
expect(handleChange).toHaveBeenCalledWith('Because')
})
})

View File

@ -0,0 +1,155 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import AddAnnotationModal from './index'
import { useProviderContext } from '@/context/provider-context'
jest.mock('@/context/provider-context', () => ({
useProviderContext: jest.fn(),
}))
const mockToastNotify = jest.fn()
jest.mock('@/app/components/base/toast', () => ({
__esModule: true,
default: {
notify: jest.fn(args => mockToastNotify(args)),
},
}))
jest.mock('@/app/components/billing/annotation-full', () => () => <div data-testid="annotation-full" />)
const mockUseProviderContext = useProviderContext as jest.Mock
const getProviderContext = ({ usage = 0, total = 10, enableBilling = false } = {}) => ({
plan: {
usage: { annotatedResponse: usage },
total: { annotatedResponse: total },
},
enableBilling,
})
describe('AddAnnotationModal', () => {
const baseProps = {
isShow: true,
onHide: jest.fn(),
onAdd: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
mockUseProviderContext.mockReturnValue(getProviderContext())
})
const typeQuestion = (value: string) => {
fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder'), {
target: { value },
})
}
const typeAnswer = (value: string) => {
fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), {
target: { value },
})
}
test('should render modal title when drawer is visible', () => {
render(<AddAnnotationModal {...baseProps} />)
expect(screen.getByText('appAnnotation.addModal.title')).toBeInTheDocument()
})
test('should capture query input text when typing', () => {
render(<AddAnnotationModal {...baseProps} />)
typeQuestion('Sample question')
expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('Sample question')
})
test('should capture answer input text when typing', () => {
render(<AddAnnotationModal {...baseProps} />)
typeAnswer('Sample answer')
expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('Sample answer')
})
test('should show annotation full notice and disable submit when quota exceeded', () => {
mockUseProviderContext.mockReturnValue(getProviderContext({ usage: 10, total: 10, enableBilling: true }))
render(<AddAnnotationModal {...baseProps} />)
expect(screen.getByTestId('annotation-full')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
})
test('should call onAdd with form values when create next enabled', async () => {
const onAdd = jest.fn().mockResolvedValue(undefined)
render(<AddAnnotationModal {...baseProps} onAdd={onAdd} />)
typeQuestion('Question value')
typeAnswer('Answer value')
fireEvent.click(screen.getByTestId('checkbox-create-next-checkbox'))
await act(async () => {
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
})
expect(onAdd).toHaveBeenCalledWith({ question: 'Question value', answer: 'Answer value' })
})
test('should reset fields after saving when create next enabled', async () => {
const onAdd = jest.fn().mockResolvedValue(undefined)
render(<AddAnnotationModal {...baseProps} onAdd={onAdd} />)
typeQuestion('Question value')
typeAnswer('Answer value')
const createNextToggle = screen.getByText('appAnnotation.addModal.createNext').previousElementSibling as HTMLElement
fireEvent.click(createNextToggle)
await act(async () => {
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
})
await waitFor(() => {
expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('')
expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('')
})
})
test('should show toast when validation fails for missing question', () => {
render(<AddAnnotationModal {...baseProps} />)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: 'appAnnotation.errorMessage.queryRequired',
}))
})
test('should show toast when validation fails for missing answer', () => {
render(<AddAnnotationModal {...baseProps} />)
typeQuestion('Filled question')
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: 'appAnnotation.errorMessage.answerRequired',
}))
})
test('should close modal when save completes and create next unchecked', async () => {
const onAdd = jest.fn().mockResolvedValue(undefined)
render(<AddAnnotationModal {...baseProps} onAdd={onAdd} />)
typeQuestion('Q')
typeAnswer('A')
await act(async () => {
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
})
expect(baseProps.onHide).toHaveBeenCalled()
})
test('should allow cancel button to close the drawer', () => {
render(<AddAnnotationModal {...baseProps} />)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(baseProps.onHide).toHaveBeenCalled()
})
})

View File

@ -101,7 +101,7 @@ const AddAnnotationModal: FC<Props> = ({
<div
className='flex items-center space-x-2'
>
<Checkbox checked={isCreateNext} onCheck={() => setIsCreateNext(!isCreateNext)} />
<Checkbox id='create-next-checkbox' checked={isCreateNext} onCheck={() => setIsCreateNext(!isCreateNext)} />
<div>{t('appAnnotation.addModal.createNext')}</div>
</div>
<div className='mt-2 flex space-x-2'>

Some files were not shown because too many files have changed in this diff Show More