diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 3dd00ee4db..c03f281858 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,4 +1,4 @@ -FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye +FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ && apt-get -y install libgmp-dev libmpfr-dev libmpc-dev diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index c1666d24cf..859f499b8e 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,8 @@ blank_issues_enabled: false contact_links: + - name: "\U0001F510 Security Vulnerabilities" + url: "https://github.com/langgenius/dify/security/advisories/new" + about: Report security vulnerabilities through GitHub Security Advisories to ensure responsible disclosure. 💡 Please do not report security vulnerabilities in public issues. - name: "\U0001F4A1 Model Providers & Plugins" url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose" about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details. diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 068ba686fa..0cae2ef552 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -15,10 +15,12 @@ jobs: # Use uv to ensure we have the same ruff version in CI and locally. - uses: astral-sh/setup-uv@v6 with: - python-version: "3.12" + python-version: "3.11" - run: | cd api uv sync --dev + # fmt first to avoid line too long + uv run ruff format .. # Fix lint errors uv run ruff check --fix . # Format code @@ -28,6 +30,8 @@ jobs: run: | uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all + uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all + uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all # Convert Optional[T] to T | None (ignoring quoted types) cat > /tmp/optional-rule.yml << 'EOF' id: convert-optional-to-union diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 17af047267..f7f464a601 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -4,10 +4,10 @@ on: push: branches: - "main" - - "deploy/dev" - - "deploy/enterprise" + - "deploy/**" - "build/**" - "release/e-*" + - "hotfix/**" tags: - "*" diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index de732c3134..cd1c86e668 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -18,7 +18,7 @@ jobs: - name: Deploy to server uses: appleboy/ssh-action@v0.1.8 with: - host: ${{ secrets.RAG_SSH_HOST }} + host: ${{ secrets.SSH_HOST }} username: ${{ secrets.SSH_USER }} key: ${{ secrets.SSH_PRIVATE_KEY }} script: | diff --git a/.github/workflows/deploy-rag-dev.yml b/.github/workflows/deploy-trigger-dev.yml similarity index 75% rename from .github/workflows/deploy-rag-dev.yml rename to .github/workflows/deploy-trigger-dev.yml index 86265aad6d..2d9a904fc5 100644 --- a/.github/workflows/deploy-rag-dev.yml +++ b/.github/workflows/deploy-trigger-dev.yml @@ -1,4 +1,4 @@ -name: Deploy RAG Dev +name: Deploy Trigger Dev permissions: contents: read @@ -7,7 +7,7 @@ on: workflow_run: workflows: ["Build and Push API & Web"] branches: - - "deploy/rag-dev" + - "deploy/trigger-dev" types: - completed @@ -16,12 +16,12 @@ jobs: runs-on: ubuntu-latest if: | github.event.workflow_run.conclusion == 'success' && - github.event.workflow_run.head_branch == 'deploy/rag-dev' + github.event.workflow_run.head_branch == 'deploy/trigger-dev' steps: - name: Deploy to server uses: appleboy/ssh-action@v0.1.8 with: - host: ${{ secrets.RAG_SSH_HOST }} + host: ${{ secrets.TRIGGER_SSH_HOST }} username: ${{ secrets.SSH_USER }} key: ${{ secrets.SSH_PRIVATE_KEY }} script: | diff --git a/AGENTS.md b/AGENTS.md index 44f7b30360..5859cd1bd9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -4,84 +4,51 @@ Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management. -The codebase consists of: +The codebase is split into: -- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture -- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19 +- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design +- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19 - **Docker deployment** (`/docker`): Containerized deployment configurations -## Development Commands +## Backend Workflow -### Backend (API) +- Run backend CLI commands through `uv run --project api `. -All Python commands must be prefixed with `uv run --project api`: +- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review. -```bash -# Start development servers -./dev/start-api # Start API server -./dev/start-worker # Start Celery worker +- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks. -# Run tests -uv run --project api pytest # Run all tests -uv run --project api pytest tests/unit_tests/ # Unit tests only -uv run --project api pytest tests/integration_tests/ # Integration tests +- Integration tests are CI-only and are not expected to run in the local environment. -# Code quality -./dev/reformat # Run all formatters and linters -uv run --project api ruff check --fix ./ # Fix linting issues -uv run --project api ruff format ./ # Format code -uv run --directory api basedpyright # Type checking -``` - -### Frontend (Web) +## Frontend Workflow ```bash cd web -pnpm lint # Run ESLint -pnpm eslint-fix # Fix ESLint issues -pnpm test # Run Jest tests +pnpm lint +pnpm lint:fix +pnpm test ``` -## Testing Guidelines +## Testing & Quality Practices -### Backend Testing +- Follow TDD: red → green → refactor. +- Use `pytest` for backend tests with Arrange-Act-Assert structure. +- Enforce strong typing; avoid `Any` and prefer explicit type annotations. +- Write self-documenting code; only add comments that explain intent. -- Use `pytest` for all backend tests -- Write tests first (TDD approach) -- Test structure: Arrange-Act-Assert +## Language Style -## Code Style Requirements +- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). +- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types. -### Python +## General Practices -- Use type hints for all functions and class attributes -- No `Any` types unless absolutely necessary -- Implement special methods (`__repr__`, `__str__`) appropriately +- Prefer editing existing files; add new documentation only when requested. +- Inject dependencies through constructors and preserve clean architecture boundaries. +- Handle errors with domain-specific exceptions at the correct layer. -### TypeScript/JavaScript +## Project Conventions -- Strict TypeScript configuration -- ESLint with Prettier integration -- Avoid `any` type - -## Important Notes - -- **Environment Variables**: Always use UV for Python commands: `uv run --project api ` -- **Comments**: Only write meaningful comments that explain "why", not "what" -- **File Creation**: Always prefer editing existing files over creating new ones -- **Documentation**: Don't create documentation files unless explicitly requested -- **Code Quality**: Always run `./dev/reformat` before committing backend changes - -## Common Development Tasks - -### Adding a New API Endpoint - -1. Create controller in `/api/controllers/` -1. Add service logic in `/api/services/` -1. Update routes in controller's `__init__.py` -1. Write tests in `/api/tests/` - -## Project-Specific Conventions - -- All async tasks use Celery with Redis as broker -- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations. +- Backend architecture adheres to DDD and Clean Architecture principles. +- Async work runs through Celery with Redis as the broker. +- Frontend user-facing strings must use `web/i18n/en-US/`; avoid hardcoded text. diff --git a/Makefile b/Makefile index ea560c7157..19c398ec82 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,6 @@ prepare-web: @echo "🌐 Setting up web environment..." @cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists" @cd web && pnpm install - @cd web && pnpm build @echo "✅ Web environment prepared (not started)" # Step 3: Prepare API environment diff --git a/README.md b/README.md index 90da1d3def..aadced582d 100644 --- a/README.md +++ b/README.md @@ -40,18 +40,18 @@

README in English - 繁體中文文件 - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in Deutsch - README in বাংলা + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production. diff --git a/api/.env.example b/api/.env.example index 78a363e506..1d8190ce5f 100644 --- a/api/.env.example +++ b/api/.env.example @@ -343,6 +343,15 @@ OCEANBASE_VECTOR_DATABASE=test OCEANBASE_MEMORY_LIMIT=6G OCEANBASE_ENABLE_HYBRID_SEARCH=false +# AlibabaCloud MySQL Vector configuration +ALIBABACLOUD_MYSQL_HOST=127.0.0.1 +ALIBABACLOUD_MYSQL_PORT=3306 +ALIBABACLOUD_MYSQL_USER=root +ALIBABACLOUD_MYSQL_PASSWORD=root +ALIBABACLOUD_MYSQL_DATABASE=dify +ALIBABACLOUD_MYSQL_MAX_CONNECTION=5 +ALIBABACLOUD_MYSQL_HNSW_M=6 + # openGauss configuration OPENGAUSS_HOST=127.0.0.1 OPENGAUSS_PORT=6600 @@ -408,6 +417,9 @@ SSRF_DEFAULT_TIME_OUT=5 SSRF_DEFAULT_CONNECT_TIME_OUT=5 SSRF_DEFAULT_READ_TIME_OUT=5 SSRF_DEFAULT_WRITE_TIME_OUT=5 +SSRF_POOL_MAX_CONNECTIONS=100 +SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20 +SSRF_POOL_KEEPALIVE_EXPIRY=5.0 BATCH_UPLOAD_LIMIT=10 KEYWORD_DATA_SOURCE_TYPE=database @@ -418,10 +430,14 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10 # CODE EXECUTION CONFIGURATION CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 CODE_EXECUTION_API_KEY=dify-sandbox +CODE_EXECUTION_SSL_VERIFY=True +CODE_EXECUTION_POOL_MAX_CONNECTIONS=100 +CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20 +CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0 CODE_MAX_NUMBER=9223372036854775807 CODE_MIN_NUMBER=-9223372036854775808 -CODE_MAX_STRING_LENGTH=80000 -TEMPLATE_TRANSFORM_MAX_LENGTH=80000 +CODE_MAX_STRING_LENGTH=400000 +TEMPLATE_TRANSFORM_MAX_LENGTH=400000 CODE_MAX_STRING_ARRAY_LENGTH=30 CODE_MAX_OBJECT_ARRAY_LENGTH=30 CODE_MAX_NUMBER_ARRAY_LENGTH=1000 @@ -461,7 +477,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 -WORKFLOW_PARALLEL_DEPTH_LIMIT=3 MAX_VARIABLE_SIZE=204800 # GraphEngine Worker Pool Configuration diff --git a/api/.ruff.toml b/api/.ruff.toml index 643bc063a1..5a29e1d8fa 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -81,7 +81,6 @@ ignore = [ "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false - "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/ ] [lint.per-file-ignores] diff --git a/api/README.md b/api/README.md index 5ecf92a4f0..e75ea3d354 100644 --- a/api/README.md +++ b/api/README.md @@ -80,10 +80,10 @@ 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. ```bash -uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation +uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation ``` -Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: +Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: ```bash uv run celery -A app.celery beat diff --git a/api/celery_entrypoint.py b/api/celery_entrypoint.py index 4d1f17430d..28fa0972e8 100644 --- a/api/celery_entrypoint.py +++ b/api/celery_entrypoint.py @@ -1,20 +1,11 @@ -import logging - import psycogreen.gevent as pscycogreen_gevent # type: ignore from grpc.experimental import gevent as grpc_gevent # type: ignore -_logger = logging.getLogger(__name__) - - -def _log(message: str): - _logger.debug(message) - - # grpc gevent grpc_gevent.init_gevent() -_log("gRPC patched with gevent.") +print("gRPC patched with gevent.", flush=True) # noqa: T201 pscycogreen_gevent.patch_psycopg() -_log("psycopg2 patched with gevent.") +print("psycopg2 patched with gevent.", flush=True) # noqa: T201 from app import app, celery diff --git a/api/commands.py b/api/commands.py index 259d823dea..82efe34611 100644 --- a/api/commands.py +++ b/api/commands.py @@ -10,6 +10,7 @@ from flask import current_app from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm): if str(new_password).strip() != str(password_confirm).strip(): click.echo(click.style("Passwords do not match.", fg="red")) return + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + account = session.query(Account).where(Account.email == email).one_or_none() - account = db.session.query(Account).where(Account.email == email).one_or_none() + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) + return - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) - return + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() - - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt - db.session.commit() - AccountService.reset_login_error_rate_limit(email) - click.echo(click.style("Password reset successfully.", fg="green")) + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt + AccountService.reset_login_error_rate_limit(email) + click.echo(click.style("Password reset successfully.", fg="green")) @click.command("reset-email", help="Reset the account email.") @@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm): if str(new_email).strip() != str(email_confirm).strip(): click.echo(click.style("New emails do not match.", fg="red")) return + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + account = session.query(Account).where(Account.email == email).one_or_none() - account = db.session.query(Account).where(Account.email == email).one_or_none() + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + try: + email_validate(new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return - try: - email_validate(new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return - - account.email = new_email - db.session.commit() - click.echo(click.style("Email updated successfully.", fg="green")) + account.email = new_email + click.echo(click.style("Email updated successfully.", fg="green")) @click.command( @@ -139,25 +138,24 @@ def reset_encrypt_key_pair(): if dify_config.EDITION != "SELF_HOSTED": click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) return + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + tenants = session.query(Tenant).all() + for tenant in tenants: + if not tenant: + click.echo(click.style("No workspaces found. Run /install first.", fg="red")) + return - tenants = db.session.query(Tenant).all() - for tenant in tenants: - if not tenant: - click.echo(click.style("No workspaces found. Run /install first.", fg="red")) - return + tenant.encrypt_public_key = generate_key_pair(tenant.id) - tenant.encrypt_public_key = generate_key_pair(tenant.id) + session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() + session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() - db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() - db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() - db.session.commit() - - click.echo( - click.style( - f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", - fg="green", + click.echo( + click.style( + f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", + fg="green", + ) ) - ) @click.command("vdb-migrate", help="Migrate vector db.") @@ -182,14 +180,15 @@ def migrate_annotation_vector_database(): try: # get apps info per_page = 50 - apps = ( - db.session.query(App) - .where(App.status == "normal") - .order_by(App.created_at.desc()) - .limit(per_page) - .offset((page - 1) * per_page) - .all() - ) + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + apps = ( + session.query(App) + .where(App.status == "normal") + .order_by(App.created_at.desc()) + .limit(per_page) + .offset((page - 1) * per_page) + .all() + ) if not apps: break except SQLAlchemyError: @@ -203,26 +202,27 @@ def migrate_annotation_vector_database(): ) try: click.echo(f"Creating app annotation index: {app.id}") - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() - ) + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + app_annotation_setting = ( + session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() + ) - if not app_annotation_setting: - skipped_count = skipped_count + 1 - click.echo(f"App annotation setting disabled: {app.id}") - continue - # get dataset_collection_binding info - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) - .first() - ) - if not dataset_collection_binding: - click.echo(f"App annotation collection binding not found: {app.id}") - continue - annotations = db.session.scalars( - select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) - ).all() + if not app_annotation_setting: + skipped_count = skipped_count + 1 + click.echo(f"App annotation setting disabled: {app.id}") + continue + # get dataset_collection_binding info + dataset_collection_binding = ( + session.query(DatasetCollectionBinding) + .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) + .first() + ) + if not dataset_collection_binding: + click.echo(f"App annotation collection binding not found: {app.id}") + continue + annotations = session.scalars( + select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) + ).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, @@ -1448,41 +1448,52 @@ def transform_datasource_credentials(): notion_credentials_tenant_mapping[tenant_id] = [] notion_credentials_tenant_mapping[tenant_id].append(notion_credential) for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): - # check notion plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if notion_plugin_id not in installed_plugins_ids: - if notion_plugin_unique_identifier: - # install notion plugin - PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) - auth_count = 0 - for notion_tenant_credential in notion_tenant_credentials: - auth_count += 1 - # get credential oauth params - access_token = notion_tenant_credential.access_token - # notion info - notion_info = notion_tenant_credential.source_info - workspace_id = notion_info.get("workspace_id") - workspace_name = notion_info.get("workspace_name") - workspace_icon = notion_info.get("workspace_icon") - new_credentials = { - "integration_secret": encrypter.encrypt_token(tenant_id, access_token), - "workspace_id": workspace_id, - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - } - datasource_provider = DatasourceProvider( - provider="notion_datasource", - tenant_id=tenant_id, - plugin_id=notion_plugin_id, - auth_type=oauth_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url=workspace_icon or "default", - is_default=False, + tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + if not tenant: + continue + try: + # check notion plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if notion_plugin_id not in installed_plugins_ids: + if notion_plugin_unique_identifier: + # install notion plugin + PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) + auth_count = 0 + for notion_tenant_credential in notion_tenant_credentials: + auth_count += 1 + # get credential oauth params + access_token = notion_tenant_credential.access_token + # notion info + notion_info = notion_tenant_credential.source_info + workspace_id = notion_info.get("workspace_id") + workspace_name = notion_info.get("workspace_name") + workspace_icon = notion_info.get("workspace_icon") + new_credentials = { + "integration_secret": encrypter.encrypt_token(tenant_id, access_token), + "workspace_id": workspace_id, + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + } + datasource_provider = DatasourceProvider( + provider="notion_datasource", + tenant_id=tenant_id, + plugin_id=notion_plugin_id, + auth_type=oauth_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url=workspace_icon or "default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_notion_count += 1 + except Exception as e: + click.echo( + click.style( + f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" + ) ) - db.session.add(datasource_provider) - deal_notion_count += 1 + continue db.session.commit() # deal firecrawl credentials deal_firecrawl_count = 0 @@ -1495,37 +1506,48 @@ def transform_datasource_credentials(): firecrawl_credentials_tenant_mapping[tenant_id] = [] firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): - # check firecrawl plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if firecrawl_plugin_id not in installed_plugins_ids: - if firecrawl_plugin_unique_identifier: - # install firecrawl plugin - PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) + tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + if not tenant: + continue + try: + # check firecrawl plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if firecrawl_plugin_id not in installed_plugins_ids: + if firecrawl_plugin_unique_identifier: + # install firecrawl plugin + PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) - auth_count = 0 - for firecrawl_tenant_credential in firecrawl_tenant_credentials: - auth_count += 1 - # get credential api key - credentials_json = json.loads(firecrawl_tenant_credential.credentials) - api_key = credentials_json.get("config", {}).get("api_key") - base_url = credentials_json.get("config", {}).get("base_url") - new_credentials = { - "firecrawl_api_key": api_key, - "base_url": base_url, - } - datasource_provider = DatasourceProvider( - provider="firecrawl", - tenant_id=tenant_id, - plugin_id=firecrawl_plugin_id, - auth_type=api_key_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url="default", - is_default=False, + auth_count = 0 + for firecrawl_tenant_credential in firecrawl_tenant_credentials: + auth_count += 1 + # get credential api key + credentials_json = json.loads(firecrawl_tenant_credential.credentials) + api_key = credentials_json.get("config", {}).get("api_key") + base_url = credentials_json.get("config", {}).get("base_url") + new_credentials = { + "firecrawl_api_key": api_key, + "base_url": base_url, + } + datasource_provider = DatasourceProvider( + provider="firecrawl", + tenant_id=tenant_id, + plugin_id=firecrawl_plugin_id, + auth_type=api_key_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url="default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_firecrawl_count += 1 + except Exception as e: + click.echo( + click.style( + f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" + ) ) - db.session.add(datasource_provider) - deal_firecrawl_count += 1 + continue db.session.commit() # deal jina credentials deal_jina_count = 0 @@ -1538,36 +1560,45 @@ def transform_datasource_credentials(): jina_credentials_tenant_mapping[tenant_id] = [] jina_credentials_tenant_mapping[tenant_id].append(jina_credential) for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): - # check jina plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if jina_plugin_id not in installed_plugins_ids: - if jina_plugin_unique_identifier: - # install jina plugin - logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier) - PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) + tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() + if not tenant: + continue + try: + # check jina plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if jina_plugin_id not in installed_plugins_ids: + if jina_plugin_unique_identifier: + # install jina plugin + logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier) + PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) - auth_count = 0 - for jina_tenant_credential in jina_tenant_credentials: - auth_count += 1 - # get credential api key - credentials_json = json.loads(jina_tenant_credential.credentials) - api_key = credentials_json.get("config", {}).get("api_key") - new_credentials = { - "integration_secret": api_key, - } - datasource_provider = DatasourceProvider( - provider="jina", - tenant_id=tenant_id, - plugin_id=jina_plugin_id, - auth_type=api_key_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url="default", - is_default=False, + auth_count = 0 + for jina_tenant_credential in jina_tenant_credentials: + auth_count += 1 + # get credential api key + credentials_json = json.loads(jina_tenant_credential.credentials) + api_key = credentials_json.get("config", {}).get("api_key") + new_credentials = { + "integration_secret": api_key, + } + datasource_provider = DatasourceProvider( + provider="jina", + tenant_id=tenant_id, + plugin_id=jina_plugin_id, + auth_type=api_key_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url="default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_jina_count += 1 + except Exception as e: + click.echo( + click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red") ) - db.session.add(datasource_provider) - deal_jina_count += 1 + continue db.session.commit() except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index b17f30210c..5b871f69f9 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -113,6 +113,21 @@ class CodeExecutionSandboxConfig(BaseSettings): default=10.0, ) + CODE_EXECUTION_POOL_MAX_CONNECTIONS: PositiveInt = Field( + description="Maximum number of concurrent connections for the code execution HTTP client", + default=100, + ) + + CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field( + description="Maximum number of persistent keep-alive connections for the code execution HTTP client", + default=20, + ) + + CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field( + description="Keep-alive expiry in seconds for idle connections (set to None to disable)", + default=5.0, + ) + CODE_MAX_NUMBER: PositiveInt = Field( description="Maximum allowed numeric value in code execution", default=9223372036854775807, @@ -135,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings): CODE_MAX_STRING_LENGTH: PositiveInt = Field( description="Maximum allowed length for strings in code execution", - default=80000, + default=400_000, ) CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( @@ -153,6 +168,11 @@ class CodeExecutionSandboxConfig(BaseSettings): default=1000, ) + CODE_EXECUTION_SSL_VERIFY: bool = Field( + description="Enable or disable SSL verification for code execution requests", + default=True, + ) + class PluginConfig(BaseSettings): """ @@ -342,11 +362,11 @@ class HttpConfig(BaseSettings): ) HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field( - ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60 + ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600 ) HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field( - ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20 + ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600 ) HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( @@ -404,6 +424,21 @@ class HttpConfig(BaseSettings): default=5, ) + SSRF_POOL_MAX_CONNECTIONS: PositiveInt = Field( + description="Maximum number of concurrent connections for the SSRF HTTP client", + default=100, + ) + + SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field( + description="Maximum number of persistent keep-alive connections for the SSRF HTTP client", + default=20, + ) + + SSRF_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field( + description="Keep-alive expiry in seconds for idle SSRF connections (set to None to disable)", + default=5.0, + ) + RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field( description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers" " when the app is behind a single trusted reverse proxy.", @@ -542,16 +577,16 @@ class WorkflowConfig(BaseSettings): default=5, ) - WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field( - description="Maximum allowed depth for nested parallel executions", - default=3, - ) - MAX_VARIABLE_SIZE: PositiveInt = Field( description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.", default=200 * 1024, ) + TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field( + description="Maximum number of characters allowed in Template Transform node output", + default=400_000, + ) + # GraphEngine Worker Pool Configuration GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field( description="Minimum number of workers per GraphEngine instance", @@ -736,7 +771,7 @@ class MailConfig(BaseSettings): MAIL_TEMPLATING_TIMEOUT: int = Field( description=""" - Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates. + Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates. Only available in sandbox mode.""", default=3, ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 62b3cc9842..d872e8201b 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig from .storage.supabase_storage_config import SupabaseStorageConfig from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig +from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig from .vdb.analyticdb_config import AnalyticdbConfig from .vdb.baidu_vector_config import BaiduVectorDBConfig from .vdb.chroma_config import ChromaConfig @@ -330,6 +331,7 @@ class MiddlewareConfig( ClickzettaConfig, HuaweiCloudConfig, MilvusConfig, + AlibabaCloudMySQLConfig, MyScaleConfig, OpenSearchConfig, OracleConfig, diff --git a/api/configs/middleware/vdb/alibabacloud_mysql_config.py b/api/configs/middleware/vdb/alibabacloud_mysql_config.py new file mode 100644 index 0000000000..a76400ed1c --- /dev/null +++ b/api/configs/middleware/vdb/alibabacloud_mysql_config.py @@ -0,0 +1,54 @@ +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings + + +class AlibabaCloudMySQLConfig(BaseSettings): + """ + Configuration settings for AlibabaCloud MySQL vector database + """ + + ALIBABACLOUD_MYSQL_HOST: str = Field( + description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')", + default="localhost", + ) + + ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field( + description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)", + default=3306, + ) + + ALIBABACLOUD_MYSQL_USER: str = Field( + description="Username for authenticating with AlibabaCloud MySQL (default is 'root')", + default="root", + ) + + ALIBABACLOUD_MYSQL_PASSWORD: str = Field( + description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)", + default="", + ) + + ALIBABACLOUD_MYSQL_DATABASE: str = Field( + description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')", + default="dify", + ) + + ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field( + description="Maximum number of connections in the connection pool", + default=5, + ) + + ALIBABACLOUD_MYSQL_CHARSET: str = Field( + description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')", + default="utf8mb4", + ) + + ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field( + description="Distance function used for vector similarity search in AlibabaCloud MySQL " + "(e.g., 'cosine', 'euclidean')", + default="cosine", + ) + + ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field( + description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)", + default=6, + ) diff --git a/api/configs/middleware/vdb/oceanbase_config.py b/api/configs/middleware/vdb/oceanbase_config.py index 99f4c49407..7c9376f86b 100644 --- a/api/configs/middleware/vdb/oceanbase_config.py +++ b/api/configs/middleware/vdb/oceanbase_config.py @@ -40,8 +40,12 @@ class OceanBaseVectorConfig(BaseSettings): OCEANBASE_FULLTEXT_PARSER: str | None = Field( description=( - "Fulltext parser to use for text indexing. Options: 'japanese_ftparser' (Japanese), " - "'thai_ftparser' (Thai), 'ik' (Chinese). Default is 'ik'" + "Fulltext parser to use for text indexing. " + "Built-in options: 'ngram' (N-gram tokenizer for English/numbers), " + "'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), " + "'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). " + "External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), " + "'thai_ftparser' (Thai tokenizer). Default is 'ik'" ), default="ik", ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index ba015a6eb9..a7d712545e 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,23 +1,24 @@ -from enum import Enum +from enum import StrEnum from typing import Literal from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings +class AuthMethod(StrEnum): + """ + Authentication method for OpenSearch + """ + + BASIC = "basic" + AWS_MANAGED_IAM = "aws_managed_iam" + + class OpenSearchConfig(BaseSettings): """ Configuration settings for OpenSearch """ - class AuthMethod(Enum): - """ - Authentication method for OpenSearch - """ - - BASIC = "basic" - AWS_MANAGED_IAM = "aws_managed_iam" - OPENSEARCH_HOST: str | None = Field( description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, diff --git a/api/configs/remote_settings_sources/nacos/http_request.py b/api/configs/remote_settings_sources/nacos/http_request.py index 6401c5830d..1a0744a21b 100644 --- a/api/configs/remote_settings_sources/nacos/http_request.py +++ b/api/configs/remote_settings_sources/nacos/http_request.py @@ -5,7 +5,7 @@ import logging import os import time -import requests +import httpx logger = logging.getLogger(__name__) @@ -30,10 +30,10 @@ class NacosHttpClient: params = {} try: self._inject_auth_info(headers, params) - response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) + response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params) response.raise_for_status() return response.text - except requests.RequestException as e: + except httpx.RequestError as e: return f"Request to Nacos failed: {e}" def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None: @@ -78,7 +78,7 @@ class NacosHttpClient: params = {"username": self.username, "password": self.password} url = "http://" + self.server + "/nacos/v1/auth/login" try: - resp = requests.request("POST", url, headers=None, params=params) + resp = httpx.request("POST", url, headers=None, params=params) resp.raise_for_status() response_data = resp.json() self.token = response_data.get("accessToken") diff --git a/api/constants/__init__.py b/api/constants/__init__.py index fe8f4f8785..9141fbea95 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1,4 +1,5 @@ from configs import dify_config +from libs.collection_utils import convert_to_lower_and_upper_set HIDDEN_VALUE = "[__HIDDEN__]" UNKNOWN_VALUE = "[__UNKNOWN__]" @@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000" DEFAULT_FILE_NUMBER_LIMITS = 3 -IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] -IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) +IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"}) -VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"] -VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS]) +VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"}) -AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] -AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) +AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"}) - -_doc_extensions: list[str] +_doc_extensions: set[str] if dify_config.ETL_TYPE == "Unstructured": - _doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] - _doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) + _doc_extensions = { + "txt", + "markdown", + "md", + "mdx", + "pdf", + "html", + "htm", + "xlsx", + "xls", + "vtt", + "properties", + "doc", + "docx", + "csv", + "eml", + "msg", + "pptx", + "xml", + "epub", + } if dify_config.UNSTRUCTURED_API_URL: - _doc_extensions.append("ppt") + _doc_extensions.add("ppt") else: - _doc_extensions = [ + _doc_extensions = { "txt", "markdown", "md", @@ -37,5 +53,5 @@ else: "csv", "vtt", "properties", - ] -DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] + } +DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ee02ff3937..621f5066e4 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -1,31 +1,10 @@ +from importlib import import_module + from flask import Blueprint from flask_restx import Namespace from libs.external_api import ExternalApi -from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi -from .explore.audio import ChatAudioApi, ChatTextApi -from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi -from .explore.conversation import ( - ConversationApi, - ConversationListApi, - ConversationPinApi, - ConversationRenameApi, - ConversationUnPinApi, -) -from .explore.message import ( - MessageFeedbackApi, - MessageListApi, - MessageMoreLikeThisApi, - MessageSuggestedQuestionApi, -) -from .explore.workflow import ( - InstalledAppWorkflowRunApi, - InstalledAppWorkflowTaskStopApi, -) -from .files import FileApi, FilePreviewApi, FileSupportTypeApi -from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi - bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi( @@ -35,23 +14,23 @@ api = ExternalApi( description="Console management APIs for app configuration, monitoring, and administration", ) -# Create namespace console_ns = Namespace("console", description="Console management API operations", path="/") -# File -api.add_resource(FileApi, "/files/upload") -api.add_resource(FilePreviewApi, "/files//preview") -api.add_resource(FileSupportTypeApi, "/files/support-type") +RESOURCE_MODULES = ( + "controllers.console.app.app_import", + "controllers.console.explore.audio", + "controllers.console.explore.completion", + "controllers.console.explore.conversation", + "controllers.console.explore.message", + "controllers.console.explore.workflow", + "controllers.console.files", + "controllers.console.remote_files", +) -# Remote files -api.add_resource(RemoteFileInfoApi, "/remote-files/") -api.add_resource(RemoteFileUploadApi, "/remote-files/upload") - -# Import App -api.add_resource(AppImportApi, "/apps/imports") -api.add_resource(AppImportConfirmApi, "/apps/imports//confirm") -api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//check-dependencies") +for module_name in RESOURCE_MODULES: + import_module(module_name) +# Ensure resource modules are imported so route decorators are evaluated. # Import other controllers from . import ( admin, @@ -150,77 +129,6 @@ from .workspace import ( workspace, ) -# Explore Audio -api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") -api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") - -# Explore Completion -api.add_resource( - CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" -) -api.add_resource( - CompletionStopApi, - "/installed-apps//completion-messages//stop", - endpoint="installed_app_stop_completion", -) -api.add_resource( - ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" -) -api.add_resource( - ChatStopApi, - "/installed-apps//chat-messages//stop", - endpoint="installed_app_stop_chat_completion", -) - -# Explore Conversation -api.add_resource( - ConversationRenameApi, - "/installed-apps//conversations//name", - endpoint="installed_app_conversation_rename", -) -api.add_resource( - ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" -) -api.add_resource( - ConversationApi, - "/installed-apps//conversations/", - endpoint="installed_app_conversation", -) -api.add_resource( - ConversationPinApi, - "/installed-apps//conversations//pin", - endpoint="installed_app_conversation_pin", -) -api.add_resource( - ConversationUnPinApi, - "/installed-apps//conversations//unpin", - endpoint="installed_app_conversation_unpin", -) - - -# Explore Message -api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") -api.add_resource( - MessageFeedbackApi, - "/installed-apps//messages//feedbacks", - endpoint="installed_app_message_feedback", -) -api.add_resource( - MessageMoreLikeThisApi, - "/installed-apps//messages//more-like-this", - endpoint="installed_app_more_like_this", -) -api.add_resource( - MessageSuggestedQuestionApi, - "/installed-apps//messages//suggested-questions", - endpoint="installed_app_suggested_question", -) -# Explore Workflow -api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") -api.add_resource( - InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" -) - api.add_namespace(console_ns) __all__ = [ diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2d2e4b448a..3927685af3 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager from extensions.ext_database import db from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from libs.login import login_required +from libs.validators import validate_description_length from models import Account, App from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService @@ -28,12 +29,6 @@ from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] -def _validate_description_length(description): - if description and len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - @console_ns.route("/apps") class AppListApi(Resource): @api.doc("list_apps") @@ -138,7 +133,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") - parser.add_argument("description", type=_validate_description_length, location="json") + parser.add_argument("description", type=validate_description_length, location="json") parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") @@ -219,7 +214,7 @@ class AppApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("description", type=_validate_description_length, location="json") + parser.add_argument("description", type=validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") @@ -297,7 +292,7 @@ class AppCopyApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, location="json") - parser.add_argument("description", type=_validate_description_length, location="json") + parser.add_argument("description", type=validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") @@ -309,7 +304,7 @@ class AppCopyApi(Resource): account = cast(Account, current_user) result = import_service.import_app( account=account, - import_mode=ImportMode.YAML_CONTENT.value, + import_mode=ImportMode.YAML_CONTENT, yaml_content=yaml_content, name=args.get("name"), description=args.get("description"), diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index aee93a8814..037561cfed 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -20,7 +20,10 @@ from services.app_dsl_service import AppDslService, ImportStatus from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService +from .. import console_ns + +@console_ns.route("/apps/imports") class AppImportApi(Resource): @setup_required @login_required @@ -67,13 +70,14 @@ class AppImportApi(Resource): EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") # Return appropriate status code based on result status = result.status - if status == ImportStatus.FAILED.value: + if status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 - elif status == ImportStatus.PENDING.value: + elif status == ImportStatus.PENDING: return result.model_dump(mode="json"), 202 return result.model_dump(mode="json"), 200 +@console_ns.route("/apps/imports//confirm") class AppImportConfirmApi(Resource): @setup_required @login_required @@ -93,11 +97,12 @@ class AppImportConfirmApi(Resource): session.commit() # Return appropriate status code based on result - if result.status == ImportStatus.FAILED.value: + if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 200 +@console_ns.route("/apps/imports//check-dependencies") class AppImportCheckDependenciesApi(Resource): @setup_required @login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index c0cbf6613e..3b8dff613b 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,6 +1,7 @@ from datetime import datetime import pytz # pip install pytz +import sqlalchemy as sa from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from flask_restx.inputs import int_range @@ -70,7 +71,7 @@ class CompletionConversationApi(Resource): parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where( + query = sa.select(Conversation).where( Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) ) @@ -236,7 +237,7 @@ class ChatConversationApi(Resource): .subquery() ) - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) + query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) if args["keyword"]: keyword_filter = f"%{args['keyword']}%" @@ -308,7 +309,7 @@ class ChatConversationApi(Resource): ) if app_model.mode == AppMode.ADVANCED_CHAT: - query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) + query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) match args["sort_by"]: case "created_at": diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 11df511840..fa6e3f8738 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -14,6 +14,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models.account import Account from models.model import AppMode, AppModelConfig @@ -90,7 +91,7 @@ class ModelConfigResource(Resource): if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity.model_validate(tool) # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( @@ -124,7 +125,7 @@ class ModelConfigResource(Resource): # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict for tool in agent_mode.get("tools") or []: - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity.model_validate(tool) # get tool key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" @@ -172,6 +173,8 @@ class ModelConfigResource(Resource): db.session.flush() app_model.app_model_config_id = new_app_model_config.id + app_model.updated_by = current_user.id + app_model.updated_at = naive_utc_now() db.session.commit() app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 6471b843c6..5974395c6a 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -52,7 +52,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource): sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), ) .select_from(Message) - .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) + .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER) ) if args["start"]: @@ -190,7 +190,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -263,7 +263,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -345,7 +345,7 @@ FROM WHERE c.app_id = :app_id AND m.invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -432,7 +432,7 @@ LEFT JOIN WHERE m.app_id = :app_id AND m.invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -509,7 +509,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -584,7 +584,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index e70765546c..578d864b80 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -9,7 +9,6 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services -from configs import dify_config from controllers.console import api, console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model @@ -26,6 +25,7 @@ from factories import file_factory, variable_factory from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper +from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models import App @@ -675,8 +675,12 @@ class PublishedWorkflowApi(Resource): marked_comment=args.marked_comment or "", ) - app_model.workflow_id = workflow.id - db.session.commit() # NOTE: this is necessary for update app_model.workflow_id + # Update app_model within the same session to ensure atomicity + app_model_in_session = session.get(App, app_model.id) + if app_model_in_session: + app_model_in_session.workflow_id = workflow.id + app_model_in_session.updated_by = current_user.id + app_model_in_session.updated_at = naive_utc_now() workflow_created_at = TimestampField().format(workflow.created_at) @@ -797,24 +801,6 @@ class ConvertToWorkflowApi(Resource): } -@console_ns.route("/apps//workflows/draft/config") -class WorkflowConfigApi(Resource): - """Resource for workflow configuration.""" - - @api.doc("get_workflow_config") - @api.doc(description="Get workflow configuration") - @api.doc(params={"app_id": "Application ID"}) - @api.response(200, "Workflow configuration retrieved successfully") - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def get(self, app_model: App): - return { - "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, - } - - @console_ns.route("/apps//workflows") class PublishedAllWorkflowApi(Resource): @api.doc("get_all_published_workflows") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 535e7cadd6..b8904bf3d9 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -47,7 +47,7 @@ WHERE arg_dict = { "tz": account.timezone, "app_id": app_model.id, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, } timezone = pytz.timezone(account.timezone) @@ -115,7 +115,7 @@ WHERE arg_dict = { "tz": account.timezone, "app_id": app_model.id, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, } timezone = pytz.timezone(account.timezone) @@ -183,7 +183,7 @@ WHERE arg_dict = { "tz": account.timezone, "app_id": app_model.id, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, } timezone = pytz.timezone(account.timezone) @@ -269,7 +269,7 @@ GROUP BY arg_dict = { "tz": account.timezone, "app_id": app_model.id, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, } timezone = pytz.timezone(account.timezone) diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8cdadfb03c..76171e3f8a 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -103,7 +103,7 @@ class ActivateApi(Resource): account.interface_language = args["interface_language"] account.timezone = args["timezone"] account.interface_theme = "light" - account.status = AccountStatus.ACTIVE.value + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 796e6916cc..207303b212 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -2,7 +2,7 @@ from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.auth.error import ApiKeyAuthFailedError from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService @@ -10,6 +10,7 @@ from services.auth.api_key_auth_service import ApiKeyAuthService from ..wraps import account_initialization_required, setup_required +@console_ns.route("/api-key-auth/data-source") class ApiKeyAuthDataSource(Resource): @setup_required @login_required @@ -33,6 +34,7 @@ class ApiKeyAuthDataSource(Resource): return {"sources": []} +@console_ns.route("/api-key-auth/data-source/binding") class ApiKeyAuthDataSourceBinding(Resource): @setup_required @login_required @@ -54,6 +56,7 @@ class ApiKeyAuthDataSourceBinding(Resource): return {"result": "success"}, 200 +@console_ns.route("/api-key-auth/data-source/") class ApiKeyAuthDataSourceBindingDelete(Resource): @setup_required @login_required @@ -66,8 +69,3 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) return {"result": "success"}, 204 - - -api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") -api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding") -api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index fc4ba3a2c7..6f1fd2f11a 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -1,6 +1,6 @@ import logging -import requests +import httpx from flask import current_app, redirect, request from flask_login import current_user from flask_restx import Resource, fields @@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource): return {"error": "Invalid code"}, 400 try: oauth_provider.get_access_token(code) - except requests.HTTPError as e: + except httpx.HTTPStatusError as e: logger.exception( "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text ) @@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource): return {"error": "Invalid provider"}, 400 try: oauth_provider.sync_data_source(binding_id) - except requests.HTTPError as e: + except httpx.HTTPStatusError as e: logger.exception( "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text ) diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 91de19a78a..d3613d9183 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.languages import languages -from controllers.console import api +from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, EmailCodeError, @@ -25,6 +25,7 @@ from services.billing_service import BillingService from services.errors.account import AccountNotFoundError, AccountRegisterError +@console_ns.route("/email-register/send-email") class EmailRegisterSendEmailApi(Resource): @setup_required @email_password_login_enabled @@ -52,6 +53,7 @@ class EmailRegisterSendEmailApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/email-register/validity") class EmailRegisterCheckApi(Resource): @setup_required @email_password_login_enabled @@ -92,6 +94,7 @@ class EmailRegisterCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/email-register") class EmailRegisterResetApi(Resource): @setup_required @email_password_login_enabled @@ -148,8 +151,3 @@ class EmailRegisterResetApi(Resource): raise AccountInFreezeError() return account - - -api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email") -api.add_resource(EmailRegisterCheckApi, "/email-register/validity") -api.add_resource(EmailRegisterResetApi, "/email-register") diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 36ccb1d562..704bcf8fb8 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -221,8 +221,3 @@ class ForgotPasswordResetApi(Resource): TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) - - -api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") -api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") -api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 3b35ab3c23..ba614aa828 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -7,7 +7,7 @@ from flask_restx import Resource, reqparse import services from configs import dify_config from constants.languages import languages -from controllers.console import api +from controllers.console import console_ns from controllers.console.auth.error import ( AuthenticationFailedError, EmailCodeError, @@ -34,6 +34,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces from services.feature_service import FeatureService +@console_ns.route("/login") class LoginApi(Resource): """Resource for user login.""" @@ -91,6 +92,7 @@ class LoginApi(Resource): return {"result": "success", "data": token_pair.model_dump()} +@console_ns.route("/logout") class LogoutApi(Resource): @setup_required def get(self): @@ -102,6 +104,7 @@ class LogoutApi(Resource): return {"result": "success"} +@console_ns.route("/reset-password") class ResetPasswordSendEmailApi(Resource): @setup_required @email_password_login_enabled @@ -130,6 +133,7 @@ class ResetPasswordSendEmailApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/email-code-login") class EmailCodeLoginSendEmailApi(Resource): @setup_required def post(self): @@ -162,6 +166,7 @@ class EmailCodeLoginSendEmailApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/email-code-login/validity") class EmailCodeLoginApi(Resource): @setup_required def post(self): @@ -218,6 +223,7 @@ class EmailCodeLoginApi(Resource): return {"result": "success", "data": token_pair.model_dump()} +@console_ns.route("/refresh-token") class RefreshTokenApi(Resource): def post(self): parser = reqparse.RequestParser() @@ -229,11 +235,3 @@ class RefreshTokenApi(Resource): return {"result": "success", "data": new_token_pair.model_dump()} except Exception as e: return {"result": "fail", "data": str(e)}, 401 - - -api.add_resource(LoginApi, "/login") -api.add_resource(LogoutApi, "/logout") -api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") -api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") -api.add_resource(ResetPasswordSendEmailApi, "/reset-password") -api.add_resource(RefreshTokenApi, "/refresh-token") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 1602ee6eea..4efeceb676 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,6 +1,6 @@ import logging -import requests +import httpx from flask import current_app, redirect, request from flask_restx import Resource from sqlalchemy import select @@ -101,8 +101,10 @@ class OAuthCallback(Resource): try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) - except requests.RequestException as e: - error_text = e.response.text if e.response else str(e) + except httpx.RequestError as e: + error_text = str(e) + if isinstance(e, httpx.HTTPStatusError): + error_text = e.response.text logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) return {"error": "OAuth process failed"}, 400 @@ -128,11 +130,11 @@ class OAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}") # Check account status - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.") - if account.status == AccountStatus.PENDING.value: - account.status = AccountStatus.ACTIVE.value + if account.status == AccountStatus.PENDING: + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index a54c1443f8..46281860ae 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -14,7 +14,7 @@ from models.account import Account from models.model import OAuthProviderApp from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService -from .. import api +from .. import console_ns P = ParamSpec("P") R = TypeVar("R") @@ -86,6 +86,7 @@ def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProvid return decorated +@console_ns.route("/oauth/provider") class OAuthServerAppApi(Resource): @setup_required @oauth_server_client_id_required @@ -108,6 +109,7 @@ class OAuthServerAppApi(Resource): ) +@console_ns.route("/oauth/provider/authorize") class OAuthServerUserAuthorizeApi(Resource): @setup_required @login_required @@ -125,6 +127,7 @@ class OAuthServerUserAuthorizeApi(Resource): ) +@console_ns.route("/oauth/provider/token") class OAuthServerUserTokenApi(Resource): @setup_required @oauth_server_client_id_required @@ -180,6 +183,7 @@ class OAuthServerUserTokenApi(Resource): ) +@console_ns.route("/oauth/provider/account") class OAuthServerUserAccountApi(Resource): @setup_required @oauth_server_client_id_required @@ -194,9 +198,3 @@ class OAuthServerUserAccountApi(Resource): "timezone": account.timezone, } ) - - -api.add_resource(OAuthServerAppApi, "/oauth/provider") -api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize") -api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token") -api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account") diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 39fc7dec6b..fa89f45122 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,12 +1,13 @@ from flask_restx import Resource, reqparse -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from libs.login import current_user, login_required from models.model import Account from services.billing_service import BillingService +@console_ns.route("/billing/subscription") class Subscription(Resource): @setup_required @login_required @@ -26,6 +27,7 @@ class Subscription(Resource): ) +@console_ns.route("/billing/invoices") class Invoices(Resource): @setup_required @login_required @@ -36,7 +38,3 @@ class Invoices(Resource): BillingService.is_tenant_owner_or_admin(current_user) assert current_user.current_tenant_id is not None return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) - - -api.add_resource(Subscription, "/billing/subscription") -api.add_resource(Invoices, "/billing/invoices") diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py index 4bc073f679..e489b48c82 100644 --- a/api/controllers/console/billing/compliance.py +++ b/api/controllers/console/billing/compliance.py @@ -6,10 +6,11 @@ from libs.helper import extract_remote_ip from libs.login import login_required from services.billing_service import BillingService -from .. import api +from .. import console_ns from ..wraps import account_initialization_required, only_edition_cloud, setup_required +@console_ns.route("/compliance/download") class ComplianceApi(Resource): @setup_required @login_required @@ -30,6 +31,3 @@ class ComplianceApi(Resource): ip=ip_address, device_info=device_info, ) - - -api.add_resource(ComplianceApi, "/compliance/download") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 3a9530af84..6d9d675e87 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -9,13 +9,13 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.indexing_runner import IndexingRunner from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields @@ -27,6 +27,10 @@ from services.datasource_provider_service import DatasourceProviderService from tasks.document_indexing_sync_task import document_indexing_sync_task +@console_ns.route( + "/data-source/integrates", + "/data-source/integrates//", +) class DataSourceApi(Resource): @setup_required @login_required @@ -109,6 +113,7 @@ class DataSourceApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/notion/pre-import/pages") class DataSourceNotionListApi(Resource): @setup_required @login_required @@ -196,6 +201,10 @@ class DataSourceNotionListApi(Resource): return {"notion_info": {**workspace_info, "pages": pages}}, 200 +@console_ns.route( + "/notion/workspaces//pages///preview", + "/datasets/notion-indexing-estimate", +) class DataSourceNotionApi(Resource): @setup_required @login_required @@ -247,14 +256,16 @@ class DataSourceNotionApi(Resource): credential_id = notion_info.get("credential_id") for page in notion_info["pages"]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": credential_id, - "notion_workspace_id": workspace_id, - "notion_obj_id": page["page_id"], - "notion_page_type": page["type"], - "tenant_id": current_user.current_tenant_id, - }, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": credential_id, + "notion_workspace_id": workspace_id, + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) @@ -269,6 +280,7 @@ class DataSourceNotionApi(Resource): return response.model_dump(), 200 +@console_ns.route("/datasets//notion/sync") class DataSourceNotionDatasetSyncApi(Resource): @setup_required @login_required @@ -285,6 +297,7 @@ class DataSourceNotionDatasetSyncApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/datasets//documents//notion/sync") class DataSourceNotionDocumentSyncApi(Resource): @setup_required @login_required @@ -301,16 +314,3 @@ class DataSourceNotionDocumentSyncApi(Resource): raise NotFound("Document not found.") document_indexing_sync_task.delay(dataset_id_str, document_id_str) return {"result": "success"}, 200 - - -api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates//") -api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages") -api.add_resource( - DataSourceNotionApi, - "/notion/workspaces//pages///preview", - "/datasets/notion-indexing-estimate", -) -api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets//notion/sync") -api.add_resource( - DataSourceNotionDocumentSyncApi, "/datasets//documents//notion/sync" -) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 2affbd6a42..72cd33eab6 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,4 +1,5 @@ -import flask_restx +from typing import Any, cast + from flask import request from flask_login import current_user from flask_restx import Resource, fields, marshal, marshal_with, reqparse @@ -23,31 +24,27 @@ from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields from libs.login import login_required +from libs.validators import validate_description_length from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile +from models.account import Account from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService -def _validate_name(name): +def _validate_name(name: str) -> str: if not name or len(name) < 1 or len(name) > 40: raise ValueError("Name must be between 1 to 40 characters.") return name -def _validate_description_length(description): - if description and len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - @console_ns.route("/datasets") class DatasetListApi(Resource): @api.doc("get_datasets") @@ -92,7 +89,7 @@ class DatasetListApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - data = marshal(datasets, dataset_detail_fields) + data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields)) for item in data: # convert embedding_model_provider to plugin standard format if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: @@ -147,7 +144,7 @@ class DatasetListApi(Resource): ) parser.add_argument( "description", - type=_validate_description_length, + type=validate_description_length, nullable=True, required=False, default="", @@ -192,7 +189,7 @@ class DatasetListApi(Resource): name=args["name"], description=args["description"], indexing_technique=args["indexing_technique"], - account=current_user, + account=cast(Account, current_user), permission=DatasetPermissionEnum.ONLY_ME, provider=args["provider"], external_knowledge_api_id=args["external_knowledge_api_id"], @@ -224,7 +221,7 @@ class DatasetApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - data = marshal(dataset, dataset_detail_fields) + data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) if dataset.indexing_technique == "high_quality": if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) @@ -288,7 +285,7 @@ class DatasetApi(Resource): help="type is required. Name must be between 1 to 40 characters.", type=_validate_name, ) - parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) + parser.add_argument("description", location="json", store_missing=False, type=validate_description_length) parser.add_argument( "indexing_technique", type=str, @@ -369,7 +366,7 @@ class DatasetApi(Resource): if dataset is None: raise NotFound("Dataset not found.") - result_data = marshal(dataset, dataset_detail_fields) + result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) tenant_id = current_user.current_tenant_id if data.get("partial_member_list") and data.get("permission") == "partial_members": @@ -503,7 +500,7 @@ class DatasetIndexingEstimateApi(Resource): if file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=args["doc_form"], ) @@ -515,14 +512,16 @@ class DatasetIndexingEstimateApi(Resource): credential_id = notion_info.get("credential_id") for page in notion_info["pages"]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": credential_id, - "notion_workspace_id": workspace_id, - "notion_obj_id": page["page_id"], - "notion_page_type": page["type"], - "tenant_id": current_user.current_tenant_id, - }, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": credential_id, + "notion_workspace_id": workspace_id, + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) @@ -530,15 +529,17 @@ class DatasetIndexingEstimateApi(Resource): website_info_list = args["info_list"]["website_info_list"] for url in website_info_list["urls"]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": website_info_list["provider"], - "job_id": website_info_list["job_id"], - "url": url, - "tenant_id": current_user.current_tenant_id, - "mode": "crawl", - "only_main_content": website_info_list["only_main_content"], - }, + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], + "url": url, + "tenant_id": current_user.current_tenant_id, + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) @@ -688,7 +689,7 @@ class DatasetApiKeyApi(Resource): ) if current_key_count >= self.max_keys: - flask_restx.abort( + api.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", code="max_keys_exceeded", @@ -733,7 +734,7 @@ class DatasetApiDeleteApi(Resource): ) if key is None: - flask_restx.abort(404, message="API key not found") + api.abort(404, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() @@ -785,7 +786,7 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.VIKINGDB | VectorType.UPSTASH ): - return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]} case ( VectorType.QDRANT | VectorType.WEAVIATE @@ -809,12 +810,13 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.MATRIXONE | VectorType.CLICKZETTA | VectorType.BAIDU + | VectorType.ALIBABACLOUD_MYSQL ): return { "retrieval_method": [ - RetrievalMethod.SEMANTIC_SEARCH.value, - RetrievalMethod.FULL_TEXT_SEARCH.value, - RetrievalMethod.HYBRID_SEARCH.value, + RetrievalMethod.SEMANTIC_SEARCH, + RetrievalMethod.FULL_TEXT_SEARCH, + RetrievalMethod.HYBRID_SEARCH, ] } case _: @@ -841,7 +843,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.VIKINGDB | VectorType.UPSTASH ): - return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]} case ( VectorType.QDRANT | VectorType.WEAVIATE @@ -863,12 +865,13 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.MATRIXONE | VectorType.CLICKZETTA | VectorType.BAIDU + | VectorType.ALIBABACLOUD_MYSQL ): return { "retrieval_method": [ - RetrievalMethod.SEMANTIC_SEARCH.value, - RetrievalMethod.FULL_TEXT_SEARCH.value, - RetrievalMethod.HYBRID_SEARCH.value, + RetrievalMethod.SEMANTIC_SEARCH, + RetrievalMethod.FULL_TEXT_SEARCH, + RetrievalMethod.HYBRID_SEARCH, ] } case _: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 5de1f6c6ee..011dacde76 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -4,6 +4,7 @@ from argparse import ArgumentTypeError from collections.abc import Sequence from typing import Literal, cast +import sqlalchemy as sa from flask import request from flask_login import current_user from flask_restx import Resource, fields, marshal, marshal_with, reqparse @@ -43,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from extensions.ext_database import db from fields.document_fields import ( dataset_and_document_fields, @@ -54,6 +55,7 @@ from fields.document_fields import ( from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile +from models.account import Account from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig @@ -211,13 +213,13 @@ class DatasetDocumentListApi(Resource): if sort == "hit_count": sub_query = ( - db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")) .group_by(DocumentSegment.document_id) .subquery() ) query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( - sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)), sort_logic(Document.position), ) elif sort == "created_at": @@ -303,7 +305,7 @@ class DatasetDocumentListApi(Resource): "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) args = parser.parse_args() - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) if not dataset.indexing_technique and not knowledge_config.indexing_technique: raise ValueError("indexing_technique is required.") @@ -393,7 +395,7 @@ class DatasetInitApi(Resource): parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) if knowledge_config.indexing_technique == "high_quality": if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") @@ -417,7 +419,9 @@ class DatasetInitApi(Resource): try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user + tenant_id=current_user.current_tenant_id, + knowledge_config=knowledge_config, + account=cast(Account, current_user), ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -451,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource): raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule - data_process_rule_dict = data_process_rule.to_dict() + data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {} response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} @@ -471,7 +475,7 @@ class DocumentIndexingEstimateApi(DocumentResource): raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form + datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() @@ -513,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not documents: return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 data_process_rule = documents[0].dataset_process_rule - data_process_rule_dict = data_process_rule.to_dict() + data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {} extract_settings = [] for document in documents: if document.indexing_status in {"completed", "error"}: @@ -534,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) @@ -542,14 +546,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": data_source_info["credential_id"], - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "tenant_id": current_user.current_tenant_id, - }, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info["credential_id"], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=document.doc_form, ) extract_settings.append(extract_setting) @@ -557,15 +563,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "url": data_source_info["url"], - "tenant_id": current_user.current_tenant_id, - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - }, + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], + "tenant_id": current_user.current_tenant_id, + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), document_model=document.doc_form, ) extract_settings.append(extract_setting) @@ -752,7 +760,7 @@ class DocumentApi(DocumentResource): } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) - document_process_rules = document.dataset_process_rule.to_dict() + document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -1072,7 +1080,9 @@ class DocumentRenameApi(DocumentResource): if not current_user.is_dataset_editor: raise Forbidden() dataset = DatasetService.get_dataset(dataset_id) - DatasetService.check_dataset_operator_permission(current_user, dataset) + if not dataset: + raise NotFound("Dataset not found.") + DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset) parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() @@ -1113,6 +1123,7 @@ class WebsiteDocumentSyncApi(DocumentResource): return {"result": "success"}, 200 +@console_ns.route("/datasets//documents//pipeline-execution-log") class DocumentPipelineExecutionLogApi(DocumentResource): @setup_required @login_required @@ -1146,29 +1157,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource): "input_data": log.input_data, "datasource_node_id": log.datasource_node_id, }, 200 - - -api.add_resource(GetProcessRuleApi, "/datasets/process-rule") -api.add_resource(DatasetDocumentListApi, "/datasets//documents") -api.add_resource(DatasetInitApi, "/datasets/init") -api.add_resource( - DocumentIndexingEstimateApi, "/datasets//documents//indexing-estimate" -) -api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets//batch//indexing-estimate") -api.add_resource(DocumentBatchIndexingStatusApi, "/datasets//batch//indexing-status") -api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") -api.add_resource(DocumentApi, "/datasets//documents/") -api.add_resource( - DocumentProcessingApi, "/datasets//documents//processing/" -) -api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") -api.add_resource(DocumentStatusApi, "/datasets//documents/status//batch") -api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") -api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") -api.add_resource(DocumentRetryApi, "/datasets//retry") -api.add_resource(DocumentRenameApi, "/datasets//documents//rename") - -api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") -api.add_resource( - DocumentPipelineExecutionLogApi, "/datasets//documents//pipeline-execution-log" -) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 463fd2d7ec..d6bd02483d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -7,7 +7,7 @@ from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.console import api +from controllers.console import console_ns from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import ( ChildChunkDeleteIndexError, @@ -37,6 +37,7 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task +@console_ns.route("/datasets//documents//segments") class DatasetDocumentSegmentListApi(Resource): @setup_required @login_required @@ -139,6 +140,7 @@ class DatasetDocumentSegmentListApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/datasets//documents//segment/") class DatasetDocumentSegmentApi(Resource): @setup_required @login_required @@ -193,6 +195,7 @@ class DatasetDocumentSegmentApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/datasets//documents//segment") class DatasetDocumentSegmentAddApi(Resource): @setup_required @login_required @@ -244,6 +247,7 @@ class DatasetDocumentSegmentAddApi(Resource): return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 +@console_ns.route("/datasets//documents//segments/") class DatasetDocumentSegmentUpdateApi(Resource): @setup_required @login_required @@ -305,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): ) args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @@ -345,6 +349,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): return {"result": "success"}, 204 +@console_ns.route( + "/datasets//documents//segments/batch_import", + "/datasets/batch_import_status/", +) class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @@ -384,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource): # send batch add segments task redis_client.setnx(indexing_cache_key, "waiting") batch_create_segment_to_index_task.delay( - str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id + str(job_id), + upload_file_id, + dataset_id, + document_id, + current_user.current_tenant_id, + current_user.id, ) except Exception as e: return {"error": str(e)}, 500 @@ -393,7 +406,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, job_id): + def get(self, job_id=None, dataset_id=None, document_id=None): + if job_id is None: + raise NotFound("The job does not exist.") job_id = str(job_id) indexing_cache_key = f"segment_batch_import_{job_id}" cache_result = redis_client.get(indexing_cache_key) @@ -403,6 +418,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): return {"job_id": job_id, "job_status": cache_result.decode()}, 200 +@console_ns.route("/datasets//documents//segments//child_chunks") class ChildChunkAddApi(Resource): @setup_required @login_required @@ -457,7 +473,8 @@ class ChildChunkAddApi(Resource): parser.add_argument("content", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + content = args["content"] + child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 @@ -546,13 +563,17 @@ class ChildChunkAddApi(Resource): parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") args = parser.parse_args() try: - chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] + chunks_data = args["chunks"] + chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data] child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunks, child_chunk_fields)}, 200 +@console_ns.route( + "/datasets//documents//segments//child_chunks/" +) class ChildChunkUpdateApi(Resource): @setup_required @login_required @@ -660,33 +681,8 @@ class ChildChunkUpdateApi(Resource): parser.add_argument("content", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - child_chunk = SegmentService.update_child_chunk( - args.get("content"), child_chunk, segment, document, dataset - ) + content = args["content"] + child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - - -api.add_resource(DatasetDocumentSegmentListApi, "/datasets//documents//segments") -api.add_resource( - DatasetDocumentSegmentApi, "/datasets//documents//segment/" -) -api.add_resource(DatasetDocumentSegmentAddApi, "/datasets//documents//segment") -api.add_resource( - DatasetDocumentSegmentUpdateApi, - "/datasets//documents//segments/", -) -api.add_resource( - DatasetDocumentSegmentBatchImportApi, - "/datasets//documents//segments/batch_import", - "/datasets/batch_import_status/", -) -api.add_resource( - ChildChunkAddApi, - "/datasets//documents//segments//child_chunks", -) -api.add_resource( - ChildChunkUpdateApi, - "/datasets//documents//segments//child_chunks/", -) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index e8f5a11b41..adf9f53523 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,3 +1,5 @@ +from typing import cast + from flask import request from flask_login import current_user from flask_restx import Resource, fields, marshal, reqparse @@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import account_initialization_required, setup_required from fields.dataset_fields import dataset_detail_fields from libs.login import login_required +from models.account import Account from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService -def _validate_name(name): +def _validate_name(name: str) -> str: if not name or len(name) < 1 or len(name) > 100: raise ValueError("Name must be between 1 to 100 characters.") return name @@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource): response = HitTestingService.external_retrieve( dataset=dataset, query=args["query"], - account=current_user, + account=cast(Account, current_user), external_retrieval_model=args["external_retrieval_model"], metadata_filtering_conditions=args["metadata_filtering_conditions"], ) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index cfbfc50873..a68e337135 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,10 +1,11 @@ import logging +from typing import cast from flask_login import current_user from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -import services.dataset_service +import services from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -20,6 +21,7 @@ from core.errors.error import ( ) from core.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields +from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService @@ -59,7 +61,7 @@ class DatasetsHitTestingBase: response = HitTestingService.retrieve( dataset=dataset, query=args["query"], - account=current_user, + account=cast(Account, current_user), retrieval_model=args["retrieval_model"], external_retrieval_model=args["external_retrieval_model"], limit=10, diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 21ab5e4fe1..8438458617 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -4,7 +4,7 @@ from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from fields.dataset_fields import dataset_metadata_fields from libs.login import login_required @@ -16,6 +16,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( from services.metadata_service import MetadataService +@console_ns.route("/datasets//metadata") class DatasetMetadataCreateApi(Resource): @setup_required @login_required @@ -27,7 +28,7 @@ class DatasetMetadataCreateApi(Resource): parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - metadata_args = MetadataArgs(**args) + metadata_args = MetadataArgs.model_validate(args) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -50,6 +51,7 @@ class DatasetMetadataCreateApi(Resource): return MetadataService.get_dataset_metadatas(dataset), 200 +@console_ns.route("/datasets//metadata/") class DatasetMetadataApi(Resource): @setup_required @login_required @@ -60,6 +62,7 @@ class DatasetMetadataApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() + name = args["name"] dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -68,7 +71,7 @@ class DatasetMetadataApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) + metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name) return metadata, 200 @setup_required @@ -87,6 +90,7 @@ class DatasetMetadataApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/datasets/metadata/built-in") class DatasetMetadataBuiltInFieldApi(Resource): @setup_required @login_required @@ -97,6 +101,7 @@ class DatasetMetadataBuiltInFieldApi(Resource): return {"fields": built_in_fields}, 200 +@console_ns.route("/datasets//metadata/built-in/") class DatasetMetadataBuiltInFieldActionApi(Resource): @setup_required @login_required @@ -116,6 +121,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/datasets//documents/metadata") class DocumentMetadataEditApi(Resource): @setup_required @login_required @@ -131,15 +137,8 @@ class DocumentMetadataEditApi(Resource): parser = reqparse.RequestParser() parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") args = parser.parse_args() - metadata_args = MetadataOperationData(**args) + metadata_args = MetadataOperationData.model_validate(args) MetadataService.update_documents_metadata(dataset, metadata_args) return {"result": "success"}, 200 - - -api.add_resource(DatasetMetadataCreateApi, "/datasets//metadata") -api.add_resource(DatasetMetadataApi, "/datasets//metadata/") -api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in") -api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets//metadata/built-in/") -api.add_resource(DocumentMetadataEditApi, "/datasets//documents/metadata") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 1a845cf326..53b5a0d965 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -1,16 +1,16 @@ -from fastapi.encoders import jsonable_encoder from flask import make_response, redirect, request from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, setup_required, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler from libs.helper import StrLen from libs.login import login_required @@ -19,6 +19,7 @@ from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService +@console_ns.route("/oauth/plugin//datasource/get-authorization-url") class DatasourcePluginOAuthAuthorizationUrl(Resource): @setup_required @login_required @@ -68,6 +69,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): return response +@console_ns.route("/oauth/plugin//datasource/callback") class DatasourceOAuthCallback(Resource): @setup_required def get(self, provider_id: str): @@ -123,6 +125,7 @@ class DatasourceOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") +@console_ns.route("/auth/plugin/datasource/") class DatasourceAuth(Resource): @setup_required @login_required @@ -165,6 +168,7 @@ class DatasourceAuth(Resource): return {"result": datasources}, 200 +@console_ns.route("/auth/plugin/datasource//delete") class DatasourceAuthDeleteApi(Resource): @setup_required @login_required @@ -188,6 +192,7 @@ class DatasourceAuthDeleteApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/auth/plugin/datasource//update") class DatasourceAuthUpdateApi(Resource): @setup_required @login_required @@ -213,6 +218,7 @@ class DatasourceAuthUpdateApi(Resource): return {"result": "success"}, 201 +@console_ns.route("/auth/plugin/datasource/list") class DatasourceAuthListApi(Resource): @setup_required @login_required @@ -225,6 +231,7 @@ class DatasourceAuthListApi(Resource): return {"result": jsonable_encoder(datasources)}, 200 +@console_ns.route("/auth/plugin/datasource/default-list") class DatasourceHardCodeAuthListApi(Resource): @setup_required @login_required @@ -237,6 +244,7 @@ class DatasourceHardCodeAuthListApi(Resource): return {"result": jsonable_encoder(datasources)}, 200 +@console_ns.route("/auth/plugin/datasource//custom-client") class DatasourceAuthOauthCustomClient(Resource): @setup_required @login_required @@ -271,6 +279,7 @@ class DatasourceAuthOauthCustomClient(Resource): return {"result": "success"}, 200 +@console_ns.route("/auth/plugin/datasource//default") class DatasourceAuthDefaultApi(Resource): @setup_required @login_required @@ -291,6 +300,7 @@ class DatasourceAuthDefaultApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/auth/plugin/datasource//update-name") class DatasourceUpdateProviderNameApi(Resource): @setup_required @login_required @@ -311,52 +321,3 @@ class DatasourceUpdateProviderNameApi(Resource): credential_id=args["credential_id"], ) return {"result": "success"}, 200 - - -api.add_resource( - DatasourcePluginOAuthAuthorizationUrl, - "/oauth/plugin//datasource/get-authorization-url", -) -api.add_resource( - DatasourceOAuthCallback, - "/oauth/plugin//datasource/callback", -) -api.add_resource( - DatasourceAuth, - "/auth/plugin/datasource/", -) - -api.add_resource( - DatasourceAuthUpdateApi, - "/auth/plugin/datasource//update", -) - -api.add_resource( - DatasourceAuthDeleteApi, - "/auth/plugin/datasource//delete", -) - -api.add_resource( - DatasourceAuthListApi, - "/auth/plugin/datasource/list", -) - -api.add_resource( - DatasourceHardCodeAuthListApi, - "/auth/plugin/datasource/default-list", -) - -api.add_resource( - DatasourceAuthOauthCustomClient, - "/auth/plugin/datasource//custom-client", -) - -api.add_resource( - DatasourceAuthDefaultApi, - "/auth/plugin/datasource//default", -) - -api.add_resource( - DatasourceUpdateProviderNameApi, - "/auth/plugin/datasource//update-name", -) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 05fa681a33..6c04cc877a 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore ) from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from libs.login import current_user, login_required @@ -13,6 +13,7 @@ from models.dataset import Pipeline from services.rag_pipeline.rag_pipeline import RagPipelineService +@console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview") class DataSourceContentPreviewApi(Resource): @setup_required @login_required @@ -49,9 +50,3 @@ class DataSourceContentPreviewApi(Resource): credential_id=args.get("credential_id"), ) return preview_content, 200 - - -api.add_resource( - DataSourceContentPreviewApi, - "/rag/pipelines//workflows/published/datasource/nodes//preview", -) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index f04b0e04c3..e021f95283 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -4,7 +4,7 @@ from flask import request from flask_restx import Resource, reqparse from sqlalchemy.orm import Session -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, enterprise_license_required, @@ -20,18 +20,19 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService logger = logging.getLogger(__name__) -def _validate_name(name): +def _validate_name(name: str) -> str: if not name or len(name) < 1 or len(name) > 40: raise ValueError("Name must be between 1 to 40 characters.") return name -def _validate_description_length(description): +def _validate_description_length(description: str) -> str: if len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description +@console_ns.route("/rag/pipeline/templates") class PipelineTemplateListApi(Resource): @setup_required @login_required @@ -45,6 +46,7 @@ class PipelineTemplateListApi(Resource): return pipeline_templates, 200 +@console_ns.route("/rag/pipeline/templates/") class PipelineTemplateDetailApi(Resource): @setup_required @login_required @@ -57,6 +59,7 @@ class PipelineTemplateDetailApi(Resource): return pipeline_template, 200 +@console_ns.route("/rag/pipeline/customized/templates/") class CustomizedPipelineTemplateApi(Resource): @setup_required @login_required @@ -73,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", @@ -85,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource): nullable=True, ) args = parser.parse_args() - pipeline_template_info = PipelineTemplateInfoEntity(**args) + pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) return 200 @@ -112,6 +115,7 @@ class CustomizedPipelineTemplateApi(Resource): return {"data": template.yaml_content}, 200 +@console_ns.route("/rag/pipelines//customized/publish") class PublishCustomizedPipelineTemplateApi(Resource): @setup_required @login_required @@ -129,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", @@ -144,21 +148,3 @@ class PublishCustomizedPipelineTemplateApi(Resource): rag_pipeline_service = RagPipelineService() rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) return {"result": "success"} - - -api.add_resource( - PipelineTemplateListApi, - "/rag/pipeline/templates", -) -api.add_resource( - PipelineTemplateDetailApi, - "/rag/pipeline/templates/", -) -api.add_resource( - CustomizedPipelineTemplateApi, - "/rag/pipeline/customized/templates/", -) -api.add_resource( - PublishCustomizedPipelineTemplateApi, - "/rag/pipelines//customized/publish", -) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 34faa4ec85..404aa42073 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,10 +1,10 @@ -from flask_login import current_user # type: ignore # type: ignore -from flask_restx import Resource, marshal, reqparse # type: ignore +from flask_login import current_user +from flask_restx import Resource, marshal, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden import services -from controllers.console import api +from controllers.console import console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import ( account_initialization_required, @@ -20,18 +20,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 40: - raise ValueError("Name must be between 1 to 40 characters.") - return name - - -def _validate_description_length(description): - if len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - +@console_ns.route("/rag/pipeline/dataset") class CreateRagPipelineDatasetApi(Resource): @setup_required @login_required @@ -84,6 +73,7 @@ class CreateRagPipelineDatasetApi(Resource): return import_info, 201 +@console_ns.route("/rag/pipeline/empty-dataset") class CreateEmptyRagPipelineDatasetApi(Resource): @setup_required @login_required @@ -108,7 +98,3 @@ class CreateEmptyRagPipelineDatasetApi(Resource): ), ) return marshal(dataset, dataset_detail_fields), 201 - - -api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset") -api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index db07e7729a..bef6bfd13e 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,24 +1,22 @@ import logging -from typing import Any, NoReturn +from typing import NoReturn from flask import Response from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, ) from controllers.console.app.workflow_draft_variable import ( - _WORKFLOW_DRAFT_VARIABLE_FIELDS, - _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage] + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage] ) from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from core.variables.segment_group import SegmentGroup -from core.variables.segments import ArrayFileSegment, FileSegment, Segment from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db @@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, logger = logging.getLogger(__name__) -def _convert_values_to_json_serializable_object(value: Segment) -> Any: - if isinstance(value, FileSegment): - return value.value.model_dump() - elif isinstance(value, ArrayFileSegment): - return [i.model_dump() for i in value.value] - elif isinstance(value, SegmentGroup): - return [_convert_values_to_json_serializable_object(i) for i in value.value] - else: - return value.value - - -def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: - value = variable.get_value() - # create a copy of the value to avoid affecting the model cache. - value = value.model_copy(deep=True) - # Refresh the url signature before returning it to client. - if isinstance(value, FileSegment): - file = value.value - file.remote_url = file.generate_url() - elif isinstance(value, ArrayFileSegment): - files = value.value - for file in files: - file.remote_url = file.generate_url() - return _convert_values_to_json_serializable_object(value) - - def _create_pagination_parser(): parser = reqparse.RequestParser() parser.add_argument( @@ -104,13 +76,14 @@ def _api_prerequisite(f): @account_initialization_required @get_rag_pipeline def wrapper(*args, **kwargs): - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) return wrapper +@console_ns.route("/rag/pipelines//workflows/draft/variables") class RagPipelineVariableCollectionApi(Resource): @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) @@ -168,6 +141,7 @@ def validate_node_id(node_id: str) -> NoReturn | None: return None +@console_ns.route("/rag/pipelines//workflows/draft/nodes//variables") class RagPipelineNodeVariableCollectionApi(Resource): @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @@ -190,6 +164,7 @@ class RagPipelineNodeVariableCollectionApi(Resource): return Response("", 204) +@console_ns.route("/rag/pipelines//workflows/draft/variables/") class RagPipelineVariableApi(Resource): _PATCH_NAME_FIELD = "name" _PATCH_VALUE_FIELD = "value" @@ -284,6 +259,7 @@ class RagPipelineVariableApi(Resource): return Response("", 204) +@console_ns.route("/rag/pipelines//workflows/draft/variables//reset") class RagPipelineVariableResetApi(Resource): @_api_prerequisite def put(self, pipeline: Pipeline, variable_id: str): @@ -325,6 +301,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList return draft_vars +@console_ns.route("/rag/pipelines//workflows/draft/system-variables") class RagPipelineSystemVariableCollectionApi(Resource): @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @@ -332,6 +309,7 @@ class RagPipelineSystemVariableCollectionApi(Resource): return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID) +@console_ns.route("/rag/pipelines//workflows/draft/environment-variables") class RagPipelineEnvironmentVariableCollectionApi(Resource): @_api_prerequisite def get(self, pipeline: Pipeline): @@ -364,26 +342,3 @@ class RagPipelineEnvironmentVariableCollectionApi(Resource): ) return {"items": env_vars_list} - - -api.add_resource( - RagPipelineVariableCollectionApi, - "/rag/pipelines//workflows/draft/variables", -) -api.add_resource( - RagPipelineNodeVariableCollectionApi, - "/rag/pipelines//workflows/draft/nodes//variables", -) -api.add_resource( - RagPipelineVariableApi, "/rag/pipelines//workflows/draft/variables/" -) -api.add_resource( - RagPipelineVariableResetApi, "/rag/pipelines//workflows/draft/variables//reset" -) -api.add_resource( - RagPipelineSystemVariableCollectionApi, "/rag/pipelines//workflows/draft/system-variables" -) -api.add_resource( - RagPipelineEnvironmentVariableCollectionApi, - "/rag/pipelines//workflows/draft/environment-variables", -) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index a447f2848a..a82872ba2b 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, @@ -20,6 +20,7 @@ from services.app_dsl_service import ImportStatus from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService +@console_ns.route("/rag/pipelines/imports") class RagPipelineImportApi(Resource): @setup_required @login_required @@ -59,13 +60,14 @@ class RagPipelineImportApi(Resource): # Return appropriate status code based on result status = result.status - if status == ImportStatus.FAILED.value: + if status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 - elif status == ImportStatus.PENDING.value: + elif status == ImportStatus.PENDING: return result.model_dump(mode="json"), 202 return result.model_dump(mode="json"), 200 +@console_ns.route("/rag/pipelines/imports//confirm") class RagPipelineImportConfirmApi(Resource): @setup_required @login_required @@ -85,11 +87,12 @@ class RagPipelineImportConfirmApi(Resource): session.commit() # Return appropriate status code based on result - if result.status == ImportStatus.FAILED.value: + if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 200 +@console_ns.route("/rag/pipelines/imports//check-dependencies") class RagPipelineImportCheckDependenciesApi(Resource): @setup_required @login_required @@ -107,6 +110,7 @@ class RagPipelineImportCheckDependenciesApi(Resource): return result.model_dump(mode="json"), 200 +@console_ns.route("/rag/pipelines//exports") class RagPipelineExportApi(Resource): @setup_required @login_required @@ -128,22 +132,3 @@ class RagPipelineExportApi(Resource): ) return {"data": result}, 200 - - -# Import Rag Pipeline -api.add_resource( - RagPipelineImportApi, - "/rag/pipelines/imports", -) -api.add_resource( - RagPipelineImportConfirmApi, - "/rag/pipelines/imports//confirm", -) -api.add_resource( - RagPipelineImportCheckDependenciesApi, - "/rag/pipelines/imports//check-dependencies", -) -api.add_resource( - RagPipelineExportApi, - "/rag/pipelines//exports", -) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index d00be3a573..a75c121fbe 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -9,8 +9,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services -from configs import dify_config -from controllers.console import api +from controllers.console import console_ns from controllers.console.app.error import ( ConversationCompletedError, DraftWorkflowNotExist, @@ -51,6 +50,7 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran logger = logging.getLogger(__name__) +@console_ns.route("/rag/pipelines//workflows/draft") class DraftRagPipelineApi(Resource): @setup_required @login_required @@ -148,6 +148,7 @@ class DraftRagPipelineApi(Resource): } +@console_ns.route("/rag/pipelines//workflows/draft/iteration/nodes//run") class RagPipelineDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -182,6 +183,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): raise InternalServerError() +@console_ns.route("/rag/pipelines//workflows/draft/loop/nodes//run") class RagPipelineDraftRunLoopNodeApi(Resource): @setup_required @login_required @@ -216,6 +218,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): raise InternalServerError() +@console_ns.route("/rag/pipelines//workflows/draft/run") class DraftRagPipelineRunApi(Resource): @setup_required @login_required @@ -250,6 +253,7 @@ class DraftRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) +@console_ns.route("/rag/pipelines//workflows/published/run") class PublishedRagPipelineRunApi(Resource): @setup_required @login_required @@ -370,6 +374,7 @@ class PublishedRagPipelineRunApi(Resource): # # return result # +@console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//run") class RagPipelinePublishedDatasourceNodeRunApi(Resource): @setup_required @login_required @@ -412,6 +417,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): ) +@console_ns.route("/rag/pipelines//workflows/draft/datasource/nodes//run") class RagPipelineDraftDatasourceNodeRunApi(Resource): @setup_required @login_required @@ -454,6 +460,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): ) +@console_ns.route("/rag/pipelines//workflows/draft/nodes//run") class RagPipelineDraftNodeRunApi(Resource): @setup_required @login_required @@ -487,6 +494,7 @@ class RagPipelineDraftNodeRunApi(Resource): return workflow_node_execution +@console_ns.route("/rag/pipelines//workflow-runs/tasks//stop") class RagPipelineTaskStopApi(Resource): @setup_required @login_required @@ -505,6 +513,7 @@ class RagPipelineTaskStopApi(Resource): return {"result": "success"} +@console_ns.route("/rag/pipelines//workflows/publish") class PublishedRagPipelineApi(Resource): @setup_required @login_required @@ -560,6 +569,7 @@ class PublishedRagPipelineApi(Resource): } +@console_ns.route("/rag/pipelines//workflows/default-workflow-block-configs") class DefaultRagPipelineBlockConfigsApi(Resource): @setup_required @login_required @@ -578,6 +588,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource): return rag_pipeline_service.get_default_block_configs() +@console_ns.route("/rag/pipelines//workflows/default-workflow-block-configs/") class DefaultRagPipelineBlockConfigApi(Resource): @setup_required @login_required @@ -609,18 +620,7 @@ class DefaultRagPipelineBlockConfigApi(Resource): return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) -class RagPipelineConfigApi(Resource): - """Resource for rag pipeline configuration.""" - - @setup_required - @login_required - @account_initialization_required - def get(self, pipeline_id): - return { - "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, - } - - +@console_ns.route("/rag/pipelines//workflows") class PublishedAllRagPipelineApi(Resource): @setup_required @login_required @@ -669,6 +669,7 @@ class PublishedAllRagPipelineApi(Resource): } +@console_ns.route("/rag/pipelines//workflows/") class RagPipelineByIdApi(Resource): @setup_required @login_required @@ -726,6 +727,7 @@ class RagPipelineByIdApi(Resource): return workflow +@console_ns.route("/rag/pipelines//workflows/published/processing/parameters") class PublishedRagPipelineSecondStepApi(Resource): @setup_required @login_required @@ -751,6 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource): } +@console_ns.route("/rag/pipelines//workflows/published/pre-processing/parameters") class PublishedRagPipelineFirstStepApi(Resource): @setup_required @login_required @@ -776,6 +779,7 @@ class PublishedRagPipelineFirstStepApi(Resource): } +@console_ns.route("/rag/pipelines//workflows/draft/pre-processing/parameters") class DraftRagPipelineFirstStepApi(Resource): @setup_required @login_required @@ -801,6 +805,7 @@ class DraftRagPipelineFirstStepApi(Resource): } +@console_ns.route("/rag/pipelines//workflows/draft/processing/parameters") class DraftRagPipelineSecondStepApi(Resource): @setup_required @login_required @@ -827,6 +832,7 @@ class DraftRagPipelineSecondStepApi(Resource): } +@console_ns.route("/rag/pipelines//workflow-runs") class RagPipelineWorkflowRunListApi(Resource): @setup_required @login_required @@ -848,6 +854,7 @@ class RagPipelineWorkflowRunListApi(Resource): return result +@console_ns.route("/rag/pipelines//workflow-runs/") class RagPipelineWorkflowRunDetailApi(Resource): @setup_required @login_required @@ -866,6 +873,7 @@ class RagPipelineWorkflowRunDetailApi(Resource): return workflow_run +@console_ns.route("/rag/pipelines//workflow-runs//node-executions") class RagPipelineWorkflowRunNodeExecutionListApi(Resource): @setup_required @login_required @@ -889,6 +897,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): return {"data": node_executions} +@console_ns.route("/rag/pipelines/datasource-plugins") class DatasourceListApi(Resource): @setup_required @login_required @@ -904,6 +913,7 @@ class DatasourceListApi(Resource): return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) +@console_ns.route("/rag/pipelines//workflows/draft/nodes//last-run") class RagPipelineWorkflowLastRunApi(Resource): @setup_required @login_required @@ -925,6 +935,7 @@ class RagPipelineWorkflowLastRunApi(Resource): return node_exec +@console_ns.route("/rag/pipelines/transform/datasets/") class RagPipelineTransformApi(Resource): @setup_required @login_required @@ -942,6 +953,7 @@ class RagPipelineTransformApi(Resource): return result +@console_ns.route("/rag/pipelines//workflows/draft/datasource/variables-inspect") class RagPipelineDatasourceVariableApi(Resource): @setup_required @login_required @@ -971,6 +983,7 @@ class RagPipelineDatasourceVariableApi(Resource): return workflow_node_execution +@console_ns.route("/rag/pipelines/recommended-plugins") class RagPipelineRecommendedPluginApi(Resource): @setup_required @login_required @@ -979,118 +992,3 @@ class RagPipelineRecommendedPluginApi(Resource): rag_pipeline_service = RagPipelineService() recommended_plugins = rag_pipeline_service.get_recommended_plugins() return recommended_plugins - - -api.add_resource( - DraftRagPipelineApi, - "/rag/pipelines//workflows/draft", -) -api.add_resource( - RagPipelineConfigApi, - "/rag/pipelines//workflows/draft/config", -) -api.add_resource( - DraftRagPipelineRunApi, - "/rag/pipelines//workflows/draft/run", -) -api.add_resource( - PublishedRagPipelineRunApi, - "/rag/pipelines//workflows/published/run", -) -api.add_resource( - RagPipelineTaskStopApi, - "/rag/pipelines//workflow-runs/tasks//stop", -) -api.add_resource( - RagPipelineDraftNodeRunApi, - "/rag/pipelines//workflows/draft/nodes//run", -) -api.add_resource( - RagPipelinePublishedDatasourceNodeRunApi, - "/rag/pipelines//workflows/published/datasource/nodes//run", -) - -api.add_resource( - RagPipelineDraftDatasourceNodeRunApi, - "/rag/pipelines//workflows/draft/datasource/nodes//run", -) - -api.add_resource( - RagPipelineDraftRunIterationNodeApi, - "/rag/pipelines//workflows/draft/iteration/nodes//run", -) - -api.add_resource( - RagPipelineDraftRunLoopNodeApi, - "/rag/pipelines//workflows/draft/loop/nodes//run", -) - -api.add_resource( - PublishedRagPipelineApi, - "/rag/pipelines//workflows/publish", -) -api.add_resource( - PublishedAllRagPipelineApi, - "/rag/pipelines//workflows", -) -api.add_resource( - DefaultRagPipelineBlockConfigsApi, - "/rag/pipelines//workflows/default-workflow-block-configs", -) -api.add_resource( - DefaultRagPipelineBlockConfigApi, - "/rag/pipelines//workflows/default-workflow-block-configs/", -) -api.add_resource( - RagPipelineByIdApi, - "/rag/pipelines//workflows/", -) -api.add_resource( - RagPipelineWorkflowRunListApi, - "/rag/pipelines//workflow-runs", -) -api.add_resource( - RagPipelineWorkflowRunDetailApi, - "/rag/pipelines//workflow-runs/", -) -api.add_resource( - RagPipelineWorkflowRunNodeExecutionListApi, - "/rag/pipelines//workflow-runs//node-executions", -) -api.add_resource( - DatasourceListApi, - "/rag/pipelines/datasource-plugins", -) -api.add_resource( - PublishedRagPipelineSecondStepApi, - "/rag/pipelines//workflows/published/processing/parameters", -) -api.add_resource( - PublishedRagPipelineFirstStepApi, - "/rag/pipelines//workflows/published/pre-processing/parameters", -) -api.add_resource( - DraftRagPipelineSecondStepApi, - "/rag/pipelines//workflows/draft/processing/parameters", -) -api.add_resource( - DraftRagPipelineFirstStepApi, - "/rag/pipelines//workflows/draft/pre-processing/parameters", -) -api.add_resource( - RagPipelineWorkflowLastRunApi, - "/rag/pipelines//workflows/draft/nodes//last-run", -) -api.add_resource( - RagPipelineTransformApi, - "/rag/pipelines/transform/datasets/", -) -api.add_resource( - RagPipelineDatasourceVariableApi, - "/rag/pipelines//workflows/draft/datasource/variables-inspect", -) - -api.add_resource( - RagPipelineRecommendedPluginApi, - "/rag/pipelines/recommended-plugins", -) diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index dc275fe18a..7c20fb49d8 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -26,9 +26,15 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) +from .. import console_ns + logger = logging.getLogger(__name__) +@console_ns.route( + "/installed-apps//audio-to-text", + endpoint="installed_app_audio", +) class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app @@ -65,6 +71,10 @@ class ChatAudioApi(InstalledAppResource): raise InternalServerError() +@console_ns.route( + "/installed-apps//text-to-audio", + endpoint="installed_app_text", +) class ChatTextApi(InstalledAppResource): def post(self, installed_app): from flask_restx import reqparse diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index a99708b7cd..1102b815eb 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -33,10 +33,16 @@ from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError +from .. import console_ns + logger = logging.getLogger(__name__) # define completion api for user +@console_ns.route( + "/installed-apps//completion-messages", + endpoint="installed_app_completion", +) class CompletionApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app @@ -87,6 +93,10 @@ class CompletionApi(InstalledAppResource): raise InternalServerError() +@console_ns.route( + "/installed-apps//completion-messages//stop", + endpoint="installed_app_stop_completion", +) class CompletionStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app @@ -100,6 +110,10 @@ class CompletionStopApi(InstalledAppResource): return {"result": "success"}, 200 +@console_ns.route( + "/installed-apps//chat-messages", + endpoint="installed_app_chat_completion", +) class ChatApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app @@ -153,6 +167,10 @@ class ChatApi(InstalledAppResource): raise InternalServerError() +@console_ns.route( + "/installed-apps//chat-messages//stop", + endpoint="installed_app_stop_chat_completion", +) class ChatStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 1aef9c544d..feabea2524 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -16,7 +16,13 @@ from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService +from .. import console_ns + +@console_ns.route( + "/installed-apps//conversations", + endpoint="installed_app_conversations", +) class ConversationListApi(InstalledAppResource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, installed_app): @@ -52,6 +58,10 @@ class ConversationListApi(InstalledAppResource): raise NotFound("Last Conversation Not Exists.") +@console_ns.route( + "/installed-apps//conversations/", + endpoint="installed_app_conversation", +) class ConversationApi(InstalledAppResource): def delete(self, installed_app, c_id): app_model = installed_app.app @@ -70,6 +80,10 @@ class ConversationApi(InstalledAppResource): return {"result": "success"}, 204 +@console_ns.route( + "/installed-apps//conversations//name", + endpoint="installed_app_conversation_rename", +) class ConversationRenameApi(InstalledAppResource): @marshal_with(simple_conversation_fields) def post(self, installed_app, c_id): @@ -95,6 +109,10 @@ class ConversationRenameApi(InstalledAppResource): raise NotFound("Conversation Not Exists.") +@console_ns.route( + "/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) class ConversationPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app @@ -114,6 +132,10 @@ class ConversationPinApi(InstalledAppResource): return {"result": "success"} +@console_ns.route( + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) class ConversationUnPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index bdc3fb0dbd..c86c243c9b 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound -from controllers.console import api +from controllers.console import console_ns from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db @@ -22,6 +22,7 @@ from services.feature_service import FeatureService logger = logging.getLogger(__name__) +@console_ns.route("/installed-apps") class InstalledAppsListApi(Resource): @login_required @account_initialization_required @@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource): return {"message": "App installed successfully"} +@console_ns.route("/installed-apps/") class InstalledAppApi(InstalledAppResource): """ update and delete an installed app @@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource): db.session.commit() return {"result": "success", "message": "App info updated successfully"} - - -api.add_resource(InstalledAppsListApi, "/installed-apps") -api.add_resource(InstalledAppApi, "/installed-apps/") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index c46c1c1f4f..b045e47846 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -36,9 +36,15 @@ from services.errors.message import ( ) from services.message_service import MessageService +from .. import console_ns + logger = logging.getLogger(__name__) +@console_ns.route( + "/installed-apps//messages", + endpoint="installed_app_messages", +) class MessageListApi(InstalledAppResource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, installed_app): @@ -66,6 +72,10 @@ class MessageListApi(InstalledAppResource): raise NotFound("First Message Not Exists.") +@console_ns.route( + "/installed-apps//messages//feedbacks", + endpoint="installed_app_message_feedback", +) class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = installed_app.app @@ -93,6 +103,10 @@ class MessageFeedbackApi(InstalledAppResource): return {"result": "success"} +@console_ns.route( + "/installed-apps//messages//more-like-this", + endpoint="installed_app_more_like_this", +) class MessageMoreLikeThisApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app @@ -139,6 +153,10 @@ class MessageMoreLikeThisApi(InstalledAppResource): raise InternalServerError() +@console_ns.route( + "/installed-apps//messages//suggested-questions", + endpoint="installed_app_suggested_question", +) class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 7742ea24a9..9c6b2aedfb 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,7 +1,7 @@ from flask_restx import marshal_with from controllers.common import fields -from controllers.console import api +from controllers.console import console_ns from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict @@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp from services.app_service import AppService +@console_ns.route("/installed-apps//parameters", endpoint="installed_app_parameters") class AppParameterApi(InstalledAppResource): """Resource for app variables.""" @@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource): return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) +@console_ns.route("/installed-apps//meta", endpoint="installed_app_meta") class ExploreAppMetaApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Get app meta""" @@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource): if not app_model: raise ValueError("App not found") return AppService().get_app_meta(app_model) - - -api.add_resource( - AppParameterApi, "/installed-apps//parameters", endpoint="installed_app_parameters" -) -api.add_resource(ExploreAppMetaApi, "/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 974222ddf7..6d627a929a 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,7 +1,7 @@ from flask_restx import Resource, fields, marshal_with, reqparse from constants.languages import languages -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required from libs.helper import AppIconUrlField from libs.login import current_user, login_required @@ -35,6 +35,7 @@ recommended_app_list_fields = { } +@console_ns.route("/explore/apps") class RecommendedAppListApi(Resource): @login_required @account_initialization_required @@ -56,13 +57,10 @@ class RecommendedAppListApi(Resource): return RecommendedAppService.get_recommended_apps_and_categories(language_prefix) +@console_ns.route("/explore/apps/") class RecommendedAppApi(Resource): @login_required @account_initialization_required def get(self, app_id): app_id = str(app_id) return RecommendedAppService.get_recommend_app_detail(app_id) - - -api.add_resource(RecommendedAppListApi, "/explore/apps") -api.add_resource(RecommendedAppApi, "/explore/apps/") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 6f05f898f9..79e4a4339e 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound -from controllers.console import api +from controllers.console import console_ns from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields @@ -25,6 +25,7 @@ message_fields = { } +@console_ns.route("/installed-apps//saved-messages", endpoint="installed_app_saved_messages") class SavedMessageListApi(InstalledAppResource): saved_message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource): return {"result": "success"} +@console_ns.route( + "/installed-apps//saved-messages/", endpoint="installed_app_saved_message" +) class SavedMessageApi(InstalledAppResource): def delete(self, installed_app, message_id): app_model = installed_app.app @@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource): SavedMessageService.delete(app_model, current_user, message_id) return {"result": "success"}, 204 - - -api.add_resource( - SavedMessageListApi, - "/installed-apps//saved-messages", - endpoint="installed_app_saved_messages", -) -api.add_resource( - SavedMessageApi, - "/installed-apps//saved-messages/", - endpoint="installed_app_saved_message", -) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 61e0f1b36a..e32f2814eb 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -27,9 +27,12 @@ from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError +from .. import console_ns + logger = logging.getLogger(__name__) +@console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): def post(self, installed_app: InstalledApp): """ @@ -70,6 +73,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): raise InternalServerError() +@console_ns.route("/installed-apps//workflows/tasks//stop") class InstalledAppWorkflowTaskStopApi(InstalledAppResource): def post(self, installed_app: InstalledApp, task_id: str): """ diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 105f802878..34f186e2f0 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -26,9 +26,12 @@ from libs.login import login_required from models import Account from services.file_service import FileService +from . import console_ns + PREVIEW_WORDS_LIMIT = 3000 +@console_ns.route("/files/upload") class FileApi(Resource): @setup_required @login_required @@ -88,6 +91,7 @@ class FileApi(Resource): return upload_file, 201 +@console_ns.route("/files//preview") class FilePreviewApi(Resource): @setup_required @login_required @@ -98,6 +102,7 @@ class FilePreviewApi(Resource): return {"content": text} +@console_ns.route("/files/support-type") class FileSupportTypeApi(Resource): @setup_required @login_required diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index dd4f34b9bd..7aaf807fb0 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -19,7 +19,10 @@ from fields.file_fields import file_fields_with_signed_url, remote_file_info_fie from models.account import Account from services.file_service import FileService +from . import console_ns + +@console_ns.route("/remote-files/") class RemoteFileInfoApi(Resource): @marshal_with(remote_file_info_fields) def get(self, url): @@ -35,6 +38,7 @@ class RemoteFileInfoApi(Resource): } +@console_ns.route("/remote-files/upload") class RemoteFileUploadApi(Resource): @marshal_with(file_fields_with_signed_url) def post(self): diff --git a/api/controllers/console/spec.py b/api/controllers/console/spec.py index ca54715fe0..1795e2d172 100644 --- a/api/controllers/console/spec.py +++ b/api/controllers/console/spec.py @@ -2,7 +2,6 @@ import logging from flask_restx import Resource -from controllers.console import api from controllers.console.wraps import ( account_initialization_required, setup_required, @@ -10,9 +9,12 @@ from controllers.console.wraps import ( from core.schemas.schema_manager import SchemaManager from libs.login import login_required +from . import console_ns + logger = logging.getLogger(__name__) +@console_ns.route("/spec/schema-definitions") class SpecSchemaDefinitionsApi(Resource): @setup_required @login_required @@ -30,6 +32,3 @@ class SpecSchemaDefinitionsApi(Resource): logger.exception("Failed to get schema definitions from local registry") # Return empty array as fallback return [], 200 - - -api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index da236ee5af..3d29b3ee61 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -3,7 +3,7 @@ from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import login_required @@ -17,6 +17,7 @@ def _validate_name(name): return name +@console_ns.route("/tags") class TagListApi(Resource): @setup_required @login_required @@ -52,6 +53,7 @@ class TagListApi(Resource): return response, 200 +@console_ns.route("/tags/") class TagUpdateDeleteApi(Resource): @setup_required @login_required @@ -89,6 +91,7 @@ class TagUpdateDeleteApi(Resource): return 204 +@console_ns.route("/tag-bindings/create") class TagBindingCreateApi(Resource): @setup_required @login_required @@ -114,6 +117,7 @@ class TagBindingCreateApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/tag-bindings/remove") class TagBindingDeleteApi(Resource): @setup_required @login_required @@ -133,9 +137,3 @@ class TagBindingDeleteApi(Resource): TagService.delete_tag_binding(args) return {"result": "success"}, 200 - - -api.add_resource(TagListApi, "/tags") -api.add_resource(TagUpdateDeleteApi, "/tags/") -api.add_resource(TagBindingCreateApi, "/tag-bindings/create") -api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 8d081ad995..965a520f70 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,7 +1,7 @@ import json import logging -import requests +import httpx from flask_restx import Resource, fields, reqparse from packaging import version @@ -57,7 +57,11 @@ class VersionApi(Resource): return result try: - response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) + response = httpx.get( + check_update_url, + params={"current_version": args["current_version"]}, + timeout=httpx.Timeout(connect=3, read=10), + ) except Exception as error: logger.warning("Check update version error: %s.", str(error)) result["version"] = args["current_version"] diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 7a41a8a5cc..e2b0e3f84d 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language -from controllers.console import api +from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, EmailChangeLimitError, @@ -45,6 +45,7 @@ from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError +@console_ns.route("/account/init") class AccountInitApi(Resource): @setup_required @login_required @@ -97,6 +98,7 @@ class AccountInitApi(Resource): return {"result": "success"} +@console_ns.route("/account/profile") class AccountProfileApi(Resource): @setup_required @login_required @@ -109,6 +111,7 @@ class AccountProfileApi(Resource): return current_user +@console_ns.route("/account/name") class AccountNameApi(Resource): @setup_required @login_required @@ -130,6 +133,7 @@ class AccountNameApi(Resource): return updated_account +@console_ns.route("/account/avatar") class AccountAvatarApi(Resource): @setup_required @login_required @@ -147,6 +151,7 @@ class AccountAvatarApi(Resource): return updated_account +@console_ns.route("/account/interface-language") class AccountInterfaceLanguageApi(Resource): @setup_required @login_required @@ -164,6 +169,7 @@ class AccountInterfaceLanguageApi(Resource): return updated_account +@console_ns.route("/account/interface-theme") class AccountInterfaceThemeApi(Resource): @setup_required @login_required @@ -181,6 +187,7 @@ class AccountInterfaceThemeApi(Resource): return updated_account +@console_ns.route("/account/timezone") class AccountTimezoneApi(Resource): @setup_required @login_required @@ -202,6 +209,7 @@ class AccountTimezoneApi(Resource): return updated_account +@console_ns.route("/account/password") class AccountPasswordApi(Resource): @setup_required @login_required @@ -227,6 +235,7 @@ class AccountPasswordApi(Resource): return {"result": "success"} +@console_ns.route("/account/integrates") class AccountIntegrateApi(Resource): integrate_fields = { "provider": fields.String, @@ -283,6 +292,7 @@ class AccountIntegrateApi(Resource): return {"data": integrate_data} +@console_ns.route("/account/delete/verify") class AccountDeleteVerifyApi(Resource): @setup_required @login_required @@ -298,6 +308,7 @@ class AccountDeleteVerifyApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/account/delete") class AccountDeleteApi(Resource): @setup_required @login_required @@ -320,6 +331,7 @@ class AccountDeleteApi(Resource): return {"result": "success"} +@console_ns.route("/account/delete/feedback") class AccountDeleteUpdateFeedbackApi(Resource): @setup_required def post(self): @@ -333,6 +345,7 @@ class AccountDeleteUpdateFeedbackApi(Resource): return {"result": "success"} +@console_ns.route("/account/education/verify") class EducationVerifyApi(Resource): verify_fields = { "token": fields.String, @@ -352,6 +365,7 @@ class EducationVerifyApi(Resource): return BillingService.EducationIdentity.verify(account.id, account.email) +@console_ns.route("/account/education") class EducationApi(Resource): status_fields = { "result": fields.Boolean, @@ -396,6 +410,7 @@ class EducationApi(Resource): return res +@console_ns.route("/account/education/autocomplete") class EducationAutoCompleteApi(Resource): data_fields = { "data": fields.List(fields.String), @@ -419,6 +434,7 @@ class EducationAutoCompleteApi(Resource): return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) +@console_ns.route("/account/change-email") class ChangeEmailSendEmailApi(Resource): @enable_change_email @setup_required @@ -467,6 +483,7 @@ class ChangeEmailSendEmailApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/account/change-email/validity") class ChangeEmailCheckApi(Resource): @enable_change_email @setup_required @@ -508,6 +525,7 @@ class ChangeEmailCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/account/change-email/reset") class ChangeEmailResetApi(Resource): @enable_change_email @setup_required @@ -547,6 +565,7 @@ class ChangeEmailResetApi(Resource): return updated_account +@console_ns.route("/account/change-email/check-email-unique") class CheckEmailUnique(Resource): @setup_required def post(self): @@ -558,28 +577,3 @@ class CheckEmailUnique(Resource): if not AccountService.check_email_unique(args["email"]): raise EmailAlreadyInUseError() return {"result": "success"} - - -# Register API resources -api.add_resource(AccountInitApi, "/account/init") -api.add_resource(AccountProfileApi, "/account/profile") -api.add_resource(AccountNameApi, "/account/name") -api.add_resource(AccountAvatarApi, "/account/avatar") -api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") -api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") -api.add_resource(AccountTimezoneApi, "/account/timezone") -api.add_resource(AccountPasswordApi, "/account/password") -api.add_resource(AccountIntegrateApi, "/account/integrates") -api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify") -api.add_resource(AccountDeleteApi, "/account/delete") -api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback") -api.add_resource(EducationVerifyApi, "/account/education/verify") -api.add_resource(EducationApi, "/account/education") -api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete") -# Change email -api.add_resource(ChangeEmailSendEmailApi, "/account/change-email") -api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity") -api.add_resource(ChangeEmailResetApi, "/account/change-email/reset") -api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique") -# api.add_resource(AccountEmailApi, '/account/email') -# api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 7c1bc7c075..99a1c1f032 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,7 +1,7 @@ from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService +@console_ns.route( + "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate" +) class LoadBalancingCredentialsValidateApi(Resource): @setup_required @login_required @@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource): return response +@console_ns.route( + "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate" +) class LoadBalancingConfigCredentialsValidateApi(Resource): @setup_required @login_required @@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): response["error"] = error return response - - -# Load Balancing Config -api.add_resource( - LoadBalancingCredentialsValidateApi, - "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", -) - -api.add_resource( - LoadBalancingConfigCredentialsValidateApi, - "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", -) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 77f0c9a735..8b89853bd9 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -6,7 +6,7 @@ from flask_restx import Resource, marshal_with, reqparse import services from configs import dify_config -from controllers.console import api +from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, EmailCodeError, @@ -33,6 +33,7 @@ from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService +@console_ns.route("/workspaces/current/members") class MemberListApi(Resource): """List all members of current tenant.""" @@ -49,6 +50,7 @@ class MemberListApi(Resource): return {"result": "success", "accounts": members}, 200 +@console_ns.route("/workspaces/current/members/invite-email") class MemberInviteEmailApi(Resource): """Invite a new member by email.""" @@ -111,6 +113,7 @@ class MemberInviteEmailApi(Resource): }, 201 +@console_ns.route("/workspaces/current/members/") class MemberCancelInviteApi(Resource): """Cancel an invitation by member id.""" @@ -143,6 +146,7 @@ class MemberCancelInviteApi(Resource): }, 200 +@console_ns.route("/workspaces/current/members//update-role") class MemberUpdateRoleApi(Resource): """Update member role.""" @@ -177,6 +181,7 @@ class MemberUpdateRoleApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/dataset-operators") class DatasetOperatorMemberListApi(Resource): """List all members of current tenant.""" @@ -193,6 +198,7 @@ class DatasetOperatorMemberListApi(Resource): return {"result": "success", "accounts": members}, 200 +@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") class SendOwnerTransferEmailApi(Resource): """Send owner transfer email.""" @@ -233,6 +239,7 @@ class SendOwnerTransferEmailApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/workspaces/current/members/owner-transfer-check") class OwnerTransferCheckApi(Resource): @setup_required @login_required @@ -278,6 +285,7 @@ class OwnerTransferCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/workspaces/current/members//owner-transfer") class OwnerTransfer(Resource): @setup_required @login_required @@ -339,14 +347,3 @@ class OwnerTransfer(Resource): raise ValueError(str(e)) return {"result": "success"} - - -api.add_resource(MemberListApi, "/workspaces/current/members") -api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") -api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") -api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") -api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") -# owner transfer -api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email") -api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check") -api.add_resource(OwnerTransfer, "/workspaces/current/members//owner-transfer") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 0c9db660aa..7012580362 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -5,7 +5,7 @@ from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -17,6 +17,7 @@ from services.billing_service import BillingService from services.model_provider_service import ModelProviderService +@console_ns.route("/workspaces/current/model-providers") class ModelProviderListApi(Resource): @setup_required @login_required @@ -45,6 +46,7 @@ class ModelProviderListApi(Resource): return jsonable_encoder({"data": provider_list}) +@console_ns.route("/workspaces/current/model-providers//credentials") class ModelProviderCredentialApi(Resource): @setup_required @login_required @@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/workspaces/current/model-providers//credentials/switch") class ModelProviderCredentialSwitchApi(Resource): @setup_required @login_required @@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/model-providers//credentials/validate") class ModelProviderValidateApi(Resource): @setup_required @login_required @@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource): return response +@console_ns.route("/workspaces//model-providers///") class ModelProviderIconApi(Resource): """ Get model provider icon @@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource): return send_file(io.BytesIO(icon), mimetype=mimetype) +@console_ns.route("/workspaces/current/model-providers//preferred-provider-type") class PreferredProviderTypeUpdateApi(Resource): @setup_required @login_required @@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/model-providers//checkout-url") class ModelProviderPaymentCheckoutUrlApi(Resource): @setup_required @login_required @@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): prefilled_email=current_user.email, ) return data - - -api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") - -api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") -api.add_resource( - ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers//credentials/switch" -) -api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") - -api.add_resource( - PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" -) -api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url") -api.add_resource( - ModelProviderIconApi, - "/workspaces//model-providers///", -) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index f174fcc5d3..d38bb16ea7 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -4,7 +4,7 @@ from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService logger = logging.getLogger(__name__) +@console_ns.route("/workspaces/current/default-model") class DefaultModelApi(Resource): @setup_required @login_required @@ -85,6 +86,7 @@ class DefaultModelApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/model-providers//models") class ModelProviderModelApi(Resource): @setup_required @login_required @@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/workspaces/current/model-providers//models/credentials") class ModelProviderModelCredentialApi(Resource): @setup_required @login_required @@ -364,6 +367,7 @@ class ModelProviderModelCredentialApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/workspaces/current/model-providers//models/credentials/switch") class ModelProviderModelCredentialSwitchApi(Resource): @setup_required @login_required @@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource): return {"result": "success"} +@console_ns.route( + "/workspaces/current/model-providers//models/enable", endpoint="model-provider-model-enable" +) class ModelProviderModelEnableApi(Resource): @setup_required @login_required @@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource): return {"result": "success"} +@console_ns.route( + "/workspaces/current/model-providers//models/disable", endpoint="model-provider-model-disable" +) class ModelProviderModelDisableApi(Resource): @setup_required @login_required @@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/model-providers//models/credentials/validate") class ModelProviderModelValidateApi(Resource): @setup_required @login_required @@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource): return response +@console_ns.route("/workspaces/current/model-providers//models/parameter-rules") class ModelProviderModelParameterRuleApi(Resource): @setup_required @login_required @@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource): return jsonable_encoder({"data": parameter_rules}) +@console_ns.route("/workspaces/current/models/model-types/") class ModelProviderAvailableModelApi(Resource): @setup_required @login_required @@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource): models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) return jsonable_encoder({"data": models}) - - -api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers//models") -api.add_resource( - ModelProviderModelEnableApi, - "/workspaces/current/model-providers//models/enable", - endpoint="model-provider-model-enable", -) -api.add_resource( - ModelProviderModelDisableApi, - "/workspaces/current/model-providers//models/disable", - endpoint="model-provider-model-disable", -) -api.add_resource( - ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" -) -api.add_resource( - ModelProviderModelCredentialSwitchApi, - "/workspaces/current/model-providers//models/credentials/switch", -) -api.add_resource( - ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" -) - -api.add_resource( - ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" -) -api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") -api.add_resource(DefaultModelApi, "/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index fd5421fa64..7c70fb8aa0 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.console import api +from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder @@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_service import PluginService +@console_ns.route("/workspaces/current/plugin/debugging-key") class PluginDebuggingKeyApi(Resource): @setup_required @login_required @@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/list") class PluginListApi(Resource): @setup_required @login_required @@ -55,6 +57,7 @@ class PluginListApi(Resource): return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) +@console_ns.route("/workspaces/current/plugin/list/latest-versions") class PluginListLatestVersionsApi(Resource): @setup_required @login_required @@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource): return jsonable_encoder({"versions": versions}) +@console_ns.route("/workspaces/current/plugin/list/installations/ids") class PluginListInstallationsFromIdsApi(Resource): @setup_required @login_required @@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource): return jsonable_encoder({"plugins": plugins}) +@console_ns.route("/workspaces/current/plugin/icon") class PluginIconApi(Resource): @setup_required def get(self): @@ -108,6 +113,7 @@ class PluginIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) +@console_ns.route("/workspaces/current/plugin/upload/pkg") class PluginUploadFromPkgApi(Resource): @setup_required @login_required @@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/upload/github") class PluginUploadFromGithubApi(Resource): @setup_required @login_required @@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/upload/bundle") class PluginUploadFromBundleApi(Resource): @setup_required @login_required @@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/install/pkg") class PluginInstallFromPkgApi(Resource): @setup_required @login_required @@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/install/github") class PluginInstallFromGithubApi(Resource): @setup_required @login_required @@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/install/marketplace") class PluginInstallFromMarketplaceApi(Resource): @setup_required @login_required @@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/marketplace/pkg") class PluginFetchMarketplacePkgApi(Resource): @setup_required @login_required @@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/fetch-manifest") class PluginFetchManifestApi(Resource): @setup_required @login_required @@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks") class PluginFetchInstallTasksApi(Resource): @setup_required @login_required @@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks/") class PluginFetchInstallTaskApi(Resource): @setup_required @login_required @@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks//delete") class PluginDeleteInstallTaskApi(Resource): @setup_required @login_required @@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks/delete_all") class PluginDeleteAllInstallTaskItemsApi(Resource): @setup_required @login_required @@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks//delete/") class PluginDeleteInstallTaskItemApi(Resource): @setup_required @login_required @@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/upgrade/marketplace") class PluginUpgradeFromMarketplaceApi(Resource): @setup_required @login_required @@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/upgrade/github") class PluginUpgradeFromGithubApi(Resource): @setup_required @login_required @@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/uninstall") class PluginUninstallApi(Resource): @setup_required @login_required @@ -453,6 +474,7 @@ class PluginUninstallApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/permission/change") class PluginChangePermissionApi(Resource): @setup_required @login_required @@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource): return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} +@console_ns.route("/workspaces/current/plugin/permission/fetch") class PluginFetchPermissionApi(Resource): @setup_required @login_required @@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource): ) +@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options") class PluginFetchDynamicSelectOptionsApi(Resource): @setup_required @login_required @@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource): return jsonable_encoder({"options": options}) +@console_ns.route("/workspaces/current/plugin/preferences/change") class PluginChangePreferencesApi(Resource): @setup_required @login_required @@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource): return jsonable_encoder({"success": True}) +@console_ns.route("/workspaces/current/plugin/preferences/fetch") class PluginFetchPreferencesApi(Resource): @setup_required @login_required @@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource): return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict}) +@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude") class PluginAutoUpgradeExcludePluginApi(Resource): @setup_required @login_required @@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource): args = req.parse_args() return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) - - -api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") -api.add_resource(PluginListApi, "/workspaces/current/plugin/list") -api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions") -api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids") -api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon") -api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg") -api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github") -api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle") -api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg") -api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github") -api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace") -api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github") -api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace") -api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest") -api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks") -api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/") -api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks//delete") -api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all") -api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks//delete/") -api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") -api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg") - -api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") -api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") - -api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options") - -api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch") -api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change") -api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 8693d99e23..9285577f72 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -10,7 +10,7 @@ from flask_restx import ( from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, enterprise_license_required, @@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool: return False +@console_ns.route("/workspaces/current/tool-providers") class ToolProviderListApi(Resource): @setup_required @login_required @@ -71,6 +72,7 @@ class ToolProviderListApi(Resource): return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) +@console_ns.route("/workspaces/current/tool-provider/builtin//tools") class ToolBuiltinProviderListToolsApi(Resource): @setup_required @login_required @@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//info") class ToolBuiltinProviderInfoApi(Resource): @setup_required @login_required @@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) +@console_ns.route("/workspaces/current/tool-provider/builtin//delete") class ToolBuiltinProviderDeleteApi(Resource): @setup_required @login_required @@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//add") class ToolBuiltinProviderAddApi(Resource): @setup_required @login_required @@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//update") class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource): return result +@console_ns.route("/workspaces/current/tool-provider/builtin//credentials") class ToolBuiltinProviderGetCredentialsApi(Resource): @setup_required @login_required @@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//icon") class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): @@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) +@console_ns.route("/workspaces/current/tool-provider/api/add") class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/remote") class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/tools") class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/update") class ToolApiProviderUpdateApi(Resource): @setup_required @login_required @@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/delete") class ToolApiProviderDeleteApi(Resource): @setup_required @login_required @@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/get") class ToolApiProviderGetApi(Resource): @setup_required @login_required @@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//credential/schema/") class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/schema") class ToolApiProviderSchemaApi(Resource): @setup_required @login_required @@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/test/pre") class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/workflow/create") class ToolWorkflowProviderCreateApi(Resource): @setup_required @login_required @@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/workflow/update") class ToolWorkflowProviderUpdateApi(Resource): @setup_required @login_required @@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/workflow/delete") class ToolWorkflowProviderDeleteApi(Resource): @setup_required @login_required @@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/workflow/get") class ToolWorkflowProviderGetApi(Resource): @setup_required @login_required @@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource): return jsonable_encoder(tool) +@console_ns.route("/workspaces/current/tool-provider/workflow/tools") class ToolWorkflowProviderListToolApi(Resource): @setup_required @login_required @@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource): ) +@console_ns.route("/workspaces/current/tools/builtin") class ToolBuiltinListApi(Resource): @setup_required @login_required @@ -624,6 +647,7 @@ class ToolBuiltinListApi(Resource): ) +@console_ns.route("/workspaces/current/tools/api") class ToolApiListApi(Resource): @setup_required @login_required @@ -642,6 +666,7 @@ class ToolApiListApi(Resource): ) +@console_ns.route("/workspaces/current/tools/workflow") class ToolWorkflowListApi(Resource): @setup_required @login_required @@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource): ) +@console_ns.route("/workspaces/current/tool-labels") class ToolLabelsApi(Resource): @setup_required @login_required @@ -672,6 +698,7 @@ class ToolLabelsApi(Resource): return jsonable_encoder(ToolLabelsService.list_tool_labels()) +@console_ns.route("/oauth/plugin//tool/authorization-url") class ToolPluginOAuthApi(Resource): @setup_required @login_required @@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource): return response +@console_ns.route("/oauth/plugin//tool/callback") class ToolOAuthCallback(Resource): @setup_required def get(self, provider): @@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") +@console_ns.route("/workspaces/current/tool-provider/builtin//default-credential") class ToolBuiltinProviderSetDefaultApi(Resource): @setup_required @login_required @@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//oauth/custom-client") class ToolOAuthCustomClient(Resource): @setup_required @login_required @@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//oauth/client-schema") class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): @setup_required @login_required @@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//credential/info") class ToolBuiltinProviderGetCredentialInfoApi(Resource): @setup_required @login_required @@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/mcp") class ToolProviderMCPApi(Resource): @setup_required @login_required @@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/tool-provider/mcp/auth") class ToolMCPAuthApi(Resource): @setup_required @login_required @@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource): raise ValueError(f"Failed to connect to MCP server: {e}") from e +@console_ns.route("/workspaces/current/tool-provider/mcp/tools/") class ToolMCPDetailApi(Resource): @setup_required @login_required @@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource): return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) +@console_ns.route("/workspaces/current/tools/mcp") class ToolMCPListAllApi(Resource): @setup_required @login_required @@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource): return [tool.to_dict() for tool in tools] +@console_ns.route("/workspaces/current/tool-provider/mcp/update/") class ToolMCPUpdateApi(Resource): @setup_required @login_required @@ -1014,6 +1051,7 @@ class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) +@console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): def get(self): parser = reqparse.RequestParser() @@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource): authorization_code = args["code"] handle_callback(state_key, authorization_code) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") - - -# tool provider -api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") - -# tool oauth -api.add_resource(ToolPluginOAuthApi, "/oauth/plugin//tool/authorization-url") -api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback") -api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin//oauth/custom-client") - -# builtin tool provider -api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") -api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") -api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin//add") -api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") -api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") -api.add_resource( - ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//default-credential" -) -api.add_resource( - ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin//credential/info" -) -api.add_resource( - ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" -) -api.add_resource( - ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin//credential/schema/", -) -api.add_resource( - ToolBuiltinProviderGetOauthClientSchemaApi, - "/workspaces/current/tool-provider/builtin//oauth/client-schema", -) -api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") - -# api tool provider -api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") -api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote") -api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools") -api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update") -api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete") -api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get") -api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema") -api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre") - -# workflow tool provider -api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create") -api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update") -api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete") -api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get") -api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools") - -# mcp tool provider -api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/") -api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp") -api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/") -api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth") -api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback") - -api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") -api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") -api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp") -api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") -api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 6bec70b5da..bc748ac3d2 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -14,7 +14,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.console import api +from controllers.console import console_ns from controllers.console.admin import admin_required from controllers.console.error import AccountNotLinkTenantError from controllers.console.wraps import ( @@ -65,6 +65,7 @@ tenants_fields = { workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} +@console_ns.route("/workspaces") class TenantListApi(Resource): @setup_required @login_required @@ -93,6 +94,7 @@ class TenantListApi(Resource): return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200 +@console_ns.route("/all-workspaces") class WorkspaceListApi(Resource): @setup_required @admin_required @@ -118,6 +120,8 @@ class WorkspaceListApi(Resource): }, 200 +@console_ns.route("/workspaces/current", endpoint="workspaces_current") +@console_ns.route("/info", endpoint="info") # Deprecated class TenantApi(Resource): @setup_required @login_required @@ -143,11 +147,10 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") - if not tenant: - raise ValueError("No tenant available") return WorkspaceService.get_tenant_info(tenant), 200 +@console_ns.route("/workspaces/switch") class SwitchWorkspaceApi(Resource): @setup_required @login_required @@ -172,6 +175,7 @@ class SwitchWorkspaceApi(Resource): return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} +@console_ns.route("/workspaces/custom-config") class CustomConfigWorkspaceApi(Resource): @setup_required @login_required @@ -202,6 +206,7 @@ class CustomConfigWorkspaceApi(Resource): return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} +@console_ns.route("/workspaces/custom-config/webapp-logo/upload") class WebappLogoWorkspaceApi(Resource): @setup_required @login_required @@ -242,6 +247,7 @@ class WebappLogoWorkspaceApi(Resource): return {"id": upload_file.id}, 201 +@console_ns.route("/workspaces/info") class WorkspaceInfoApi(Resource): @setup_required @login_required @@ -261,13 +267,3 @@ class WorkspaceInfoApi(Resource): db.session.commit() return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} - - -api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants -api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants -api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info -api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated -api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant -api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") -api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") -api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 04102c49f3..1f588bedce 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -24,20 +24,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: NOTE: user_id is not trusted, it could be maliciously set to any value. As a result, it could only be considered as an end user id. """ + if not user_id: + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID + is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID try: with Session(db.engine) as session: - if not user_id: - user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value + user_model = None - user_model = ( - session.query(EndUser) - .where( - EndUser.id == user_id, - EndUser.tenant_id == tenant_id, - ) - .first() - ) - if not user_model: + if is_anonymous: user_model = ( session.query(EndUser) .where( @@ -46,11 +40,21 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: ) .first() ) + else: + user_model = ( + session.query(EndUser) + .where( + EndUser.id == user_id, + EndUser.tenant_id == tenant_id, + ) + .first() + ) + if not user_model: user_model = EndUser( tenant_id=tenant_id, type="service_api", - is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, + is_anonymous=is_anonymous, session_id=user_id, ) session.add(user_model) @@ -81,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None): raise ValueError("tenant_id is required") if not user_id: - user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID try: tenant_model = ( @@ -124,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo raise ValueError("invalid json") try: - payload = payload_type(**data) + payload = payload_type.model_validate(data) except Exception as e: raise ValueError(f"invalid payload: {str(e)}") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 6a70345f7c..92bbb76f0f 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,10 +1,10 @@ -from typing import Literal +from typing import Any, Literal, cast from flask import request from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, NotFound -import services.dataset_service +import services from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.wraps import ( @@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import build_dataset_tag_fields from libs.login import current_user +from libs.validators import validate_description_length from models.account import Account from models.dataset import Dataset, DatasetPermissionEnum from models.provider_ids import ModelProviderID @@ -31,12 +32,6 @@ def _validate_name(name): return name -def _validate_description_length(description): - if description and len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - # Define parsers for dataset operations dataset_create_parser = reqparse.RequestParser() dataset_create_parser.add_argument( @@ -48,7 +43,7 @@ dataset_create_parser.add_argument( ) dataset_create_parser.add_argument( "description", - type=_validate_description_length, + type=validate_description_length, nullable=True, required=False, default="", @@ -101,7 +96,7 @@ dataset_update_parser.add_argument( type=_validate_name, ) dataset_update_parser.add_argument( - "description", location="json", store_missing=False, type=_validate_description_length + "description", location="json", store_missing=False, type=validate_description_length ) dataset_update_parser.add_argument( "indexing_technique", @@ -254,19 +249,21 @@ class DatasetListApi(DatasetApiResource): """Resource for creating datasets.""" args = dataset_create_parser.parse_args() - if args.get("embedding_model_provider"): - DatasetService.check_embedding_model_setting( - tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") - ) + embedding_model_provider = args.get("embedding_model_provider") + embedding_model = args.get("embedding_model") + if embedding_model_provider and embedding_model: + DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) + + retrieval_model = args.get("retrieval_model") if ( - args.get("retrieval_model") - and args.get("retrieval_model").get("reranking_model") - and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + retrieval_model + and retrieval_model.get("reranking_model") + and retrieval_model.get("reranking_model").get("reranking_provider_name") ): DatasetService.check_reranking_model_setting( tenant_id, - args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), - args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + retrieval_model.get("reranking_model").get("reranking_provider_name"), + retrieval_model.get("reranking_model").get("reranking_model_name"), ) try: @@ -283,7 +280,7 @@ class DatasetListApi(DatasetApiResource): external_knowledge_id=args["external_knowledge_id"], embedding_model_provider=args["embedding_model_provider"], embedding_model_name=args["embedding_model"], - retrieval_model=RetrievalModel(**args["retrieval_model"]) + retrieval_model=RetrievalModel.model_validate(args["retrieval_model"]) if args["retrieval_model"] is not None else None, ) @@ -317,7 +314,7 @@ class DatasetApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - data = marshal(dataset, dataset_detail_fields) + data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) # check embedding setting provider_manager = ProviderManager() assert isinstance(current_user, Account) @@ -331,8 +328,8 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": - item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" + if data.get("indexing_technique") == "high_quality": + item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True else: @@ -341,7 +338,9 @@ class DatasetApi(DatasetApiResource): data["embedding_available"] = True # force update search method to keyword_search if indexing_technique is economic - data["retrieval_model_dict"]["search_method"] = "keyword_search" + retrieval_model_dict = data.get("retrieval_model_dict") + if retrieval_model_dict: + retrieval_model_dict["search_method"] = "keyword_search" if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) @@ -372,19 +371,24 @@ class DatasetApi(DatasetApiResource): data = request.get_json() # check embedding model setting - if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"): - DatasetService.check_embedding_model_setting( - dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") - ) + embedding_model_provider = data.get("embedding_model_provider") + embedding_model = data.get("embedding_model") + if data.get("indexing_technique") == "high_quality" or embedding_model_provider: + if embedding_model_provider and embedding_model: + DatasetService.check_embedding_model_setting( + dataset.tenant_id, embedding_model_provider, embedding_model + ) + + retrieval_model = data.get("retrieval_model") if ( - data.get("retrieval_model") - and data.get("retrieval_model").get("reranking_model") - and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + retrieval_model + and retrieval_model.get("reranking_model") + and retrieval_model.get("reranking_model").get("reranking_provider_name") ): DatasetService.check_reranking_model_setting( dataset.tenant_id, - data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), - data.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + retrieval_model.get("reranking_model").get("reranking_provider_name"), + retrieval_model.get("reranking_model").get("reranking_model_name"), ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator @@ -397,7 +401,7 @@ class DatasetApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") - result_data = marshal(dataset, dataset_detail_fields) + result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) assert isinstance(current_user, Account) tenant_id = current_user.current_tenant_id @@ -591,9 +595,10 @@ class DatasetTagsApi(DatasetApiResource): args = tag_update_parser.parse_args() args["type"] = "knowledge" - tag = TagService.update_tags(args, args.get("tag_id")) + tag_id = args["tag_id"] + tag = TagService.update_tags(args, tag_id) - binding_count = TagService.get_tag_binding_count(args.get("tag_id")) + binding_count = TagService.get_tag_binding_count(tag_id) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} @@ -616,7 +621,7 @@ class DatasetTagsApi(DatasetApiResource): if not current_user.has_edit_permission: raise Forbidden() args = tag_delete_parser.parse_args() - TagService.delete_tag(args.get("tag_id")) + TagService.delete_tag(args["tag_id"]) return 204 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d26c64fe36..961a338bc5 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -30,7 +30,6 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment -from models.model import EndUser from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -109,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource): if text is None or name is None: raise ValueError("Both 'text' and 'name' must be non-null values.") - if args.get("embedding_model_provider"): - DatasetService.check_embedding_model_setting( - tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") - ) + embedding_model_provider = args.get("embedding_model_provider") + embedding_model = args.get("embedding_model") + if embedding_model_provider and embedding_model: + DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) + + retrieval_model = args.get("retrieval_model") if ( - args.get("retrieval_model") - and args.get("retrieval_model").get("reranking_model") - and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + retrieval_model + and retrieval_model.get("reranking_model") + and retrieval_model.get("reranking_model").get("reranking_provider_name") ): DatasetService.check_reranking_model_setting( tenant_id, - args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), - args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + retrieval_model.get("reranking_model").get("reranking_provider_name"), + retrieval_model.get("reranking_model").get("reranking_model_name"), ) if not current_user: @@ -135,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource): "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } args["data_source"] = data_source - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) # validate args DocumentService.document_create_args_validate(knowledge_config) @@ -188,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource): if not dataset: raise ValueError("Dataset does not exist.") + retrieval_model = args.get("retrieval_model") if ( - args.get("retrieval_model") - and args.get("retrieval_model").get("reranking_model") - and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + retrieval_model + and retrieval_model.get("reranking_model") + and retrieval_model.get("reranking_model").get("reranking_provider_name") ): DatasetService.check_reranking_model_setting( tenant_id, - args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), - args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + retrieval_model.get("reranking_model").get("reranking_provider_name"), + retrieval_model.get("reranking_model").get("reranking_model_name"), ) # indexing_technique is already set in dataset since this is an update @@ -219,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) DocumentService.document_create_args_validate(knowledge_config) try: @@ -311,8 +313,6 @@ class DocumentAddByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError - if not isinstance(current_user, EndUser): - raise ValueError("Invalid user account") if not current_user: raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_file( @@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource): } args["data_source"] = data_source # validate args - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) DocumentService.document_create_args_validate(knowledge_config) dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None @@ -406,9 +406,6 @@ class DocumentUpdateByFileApi(DatasetApiResource): if not current_user: raise ValueError("current_user is required") - if not isinstance(current_user, EndUser): - raise ValueError("Invalid user account") - try: upload_file = FileService(db.engine).upload_file( filename=file.filename, @@ -429,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): # validate args args["original_document_id"] = str(document_id) - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) DocumentService.document_create_args_validate(knowledge_config) try: diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index c6032048e6..51420fdd5f 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): def post(self, tenant_id, dataset_id): """Create metadata for a dataset.""" args = metadata_create_parser.parse_args() - metadata_args = MetadataArgs(**args) + metadata_args = MetadataArgs.model_validate(args) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) + metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"]) return marshal(metadata, dataset_metadata_fields), 200 @service_api_ns.doc("delete_dataset_metadata") @@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, current_user) args = document_metadata_parser.parse_args() - metadata_args = MetadataOperationData(**args) + metadata_args = MetadataOperationData.model_validate(args) MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index f05325d711..13ef8abc2d 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource): parser.add_argument("is_published", type=bool, required=True, location="json") args: ParseResult = parser.parse_args() - datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args) + datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args) assert isinstance(current_user, Account) rag_pipeline_service: RagPipelineService = RagPipelineService() pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index a22155b07a..d674c7467d 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource): args = segment_update_parser.parse_args() updated_segment = SegmentService.update_segment( - SegmentUpdateArgs(**args["segment"]), segment, document, dataset + SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset ) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index ee8e1d105b..2c9be4e887 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = Create or update session terminal based on user ID. """ if not user_id: - user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID with Session(db.engine, expire_on_commit=False) as session: end_user = ( @@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = tenant_id=app_model.tenant_id, app_id=app_model.id, type="service_api", - is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, + is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID, session_id=user_id, ) session.add(end_user) diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 6f7105a724..7190f06426 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: end_user_id = enterprise_user_decoded.get("end_user_id") session_id = enterprise_user_decoded.get("session_id") user_auth_type = enterprise_user_decoded.get("auth_type") + exchanged_token_expires_unix = enterprise_user_decoded.get("exp") + if not user_auth_type: raise Unauthorized("Missing auth_type in the token.") @@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: ) db.session.add(end_user) db.session.commit() - exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) - exp = int(exp_dt.timestamp()) + + exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp()) + if exchanged_token_expires_unix: + exp = int(exchanged_token_expires_unix) + payload = { "iss": site.id, "sub": "Web API Passport", diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index eab26e5af9..c1f336fdde 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -40,7 +40,7 @@ class AgentConfigManager: "credential_id": tool.get("credential_id", None), } - agent_tools.append(AgentToolEntity(**agent_tool_properties)) + agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties)) if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { "react_router", diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 4b824bde76..aacafb2dad 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,4 +1,5 @@ import uuid +from typing import Literal, cast from core.app.app_config.entities import ( DatasetEntity, @@ -74,6 +75,9 @@ class DatasetConfigManager: return None query_variable = config.get("dataset_query_variable") + metadata_model_config_dict = dataset_configs.get("metadata_model_config") + metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions") + if dataset_configs["retrieval_model"] == "single": return DatasetEntity( dataset_ids=dataset_ids, @@ -82,18 +86,23 @@ class DatasetConfigManager: retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( dataset_configs["retrieval_model"] ), - metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), - metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) - if dataset_configs.get("metadata_model_config") + metadata_filtering_mode=cast( + Literal["disabled", "automatic", "manual"], + dataset_configs.get("metadata_filtering_mode", "disabled"), + ), + metadata_model_config=ModelConfig(**metadata_model_config_dict) + if isinstance(metadata_model_config_dict, dict) else None, - metadata_filtering_conditions=MetadataFilteringCondition( - **dataset_configs.get("metadata_filtering_conditions", {}) - ) - if dataset_configs.get("metadata_filtering_conditions") + metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict) + if isinstance(metadata_filtering_conditions_dict, dict) else None, ), ) else: + score_threshold_val = dataset_configs.get("score_threshold") + reranking_model_val = dataset_configs.get("reranking_model") + weights_val = dataset_configs.get("weights") + return DatasetEntity( dataset_ids=dataset_ids, retrieve_config=DatasetRetrieveConfigEntity( @@ -101,22 +110,23 @@ class DatasetConfigManager: retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( dataset_configs["retrieval_model"] ), - top_k=dataset_configs.get("top_k", 4), - score_threshold=dataset_configs.get("score_threshold") - if dataset_configs.get("score_threshold_enabled", False) + top_k=int(dataset_configs.get("top_k", 4)), + score_threshold=float(score_threshold_val) + if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None else None, - reranking_model=dataset_configs.get("reranking_model"), - weights=dataset_configs.get("weights"), - reranking_enabled=dataset_configs.get("reranking_enabled", True), + reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None, + weights=weights_val if isinstance(weights_val, dict) else None, + reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)), rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), - metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), - metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) - if dataset_configs.get("metadata_model_config") + metadata_filtering_mode=cast( + Literal["disabled", "automatic", "manual"], + dataset_configs.get("metadata_filtering_mode", "disabled"), + ), + metadata_model_config=ModelConfig(**metadata_model_config_dict) + if isinstance(metadata_model_config_dict, dict) else None, - metadata_filtering_conditions=MetadataFilteringCondition( - **dataset_configs.get("metadata_filtering_conditions", {}) - ) - if dataset_configs.get("metadata_filtering_conditions") + metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict) + if isinstance(metadata_filtering_conditions_dict, dict) else None, ), ) @@ -134,18 +144,17 @@ class DatasetConfigManager: config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config) # dataset_configs - if not config.get("dataset_configs"): - config["dataset_configs"] = {"retrieval_model": "single"} + if "dataset_configs" not in config or not config.get("dataset_configs"): + config["dataset_configs"] = {} + config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single") if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - if not config["dataset_configs"].get("datasets"): + if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"): config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} - need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( - "datasets", {} - ).get("datasets") + need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets") if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion @@ -166,8 +175,8 @@ class DatasetConfigManager: :param config: app model config args """ # Extract dataset config for legacy compatibility - if not config.get("agent_mode"): - config["agent_mode"] = {"enabled": False, "tools": []} + if "agent_mode" not in config or not config.get("agent_mode"): + config["agent_mode"] = {} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -180,19 +189,22 @@ class DatasetConfigManager: raise ValueError("enabled in agent_mode must be of boolean type") # tools - if not config["agent_mode"].get("tools"): + if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"): config["agent_mode"]["tools"] = [] if not isinstance(config["agent_mode"]["tools"], list): raise ValueError("tools in agent_mode must be a list of objects") # strategy - if not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER has_datasets = False - if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: - for tool in config["agent_mode"]["tools"]: + if config.get("agent_mode", {}).get("strategy") in { + PlanningStrategy.ROUTER, + PlanningStrategy.REACT_ROUTER, + }: + for tool in config.get("agent_mode", {}).get("tools", []): key = list(tool.keys())[0] if key == "dataset": # old style, use tool name as key @@ -217,7 +229,7 @@ class DatasetConfigManager: has_datasets = True - need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"] + need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled") if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 5b5eefe315..7cd5fe75d5 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -68,7 +68,7 @@ class ModelConfigConverter: # get model mode model_mode = model_config.mode if not model_mode: - model_mode = LLMMode.CHAT.value + model_mode = LLMMode.CHAT if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index ec4f6074ab..21614c010c 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -100,7 +100,7 @@ class PromptTemplateConfigManager: if config["model"]["mode"] not in model_mode_vals: raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") - if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value: + if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION: user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] @@ -110,7 +110,7 @@ class PromptTemplateConfigManager: if not assistant_prefix: config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant" - if config["model"]["mode"] == ModelMode.CHAT.value: + if config["model"]["mode"] == ModelMode.CHAT: prompt_list = config["chat_prompt_config"]["prompt"] if len(prompt_list) > 10: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index af8b7e4e17..919b135ec9 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if not app_record: raise ValueError("App not found") - if self.application_generate_entity.single_iteration_run: - # if only single iteration run is requested - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + # Handle single iteration or single loop run + graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, - node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), - graph_runtime_state=graph_runtime_state, - ) - elif self.application_generate_entity.single_loop_run: - # if only single loop run is requested - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=self._workflow, - node_id=self.application_generate_entity.single_loop_run.node_id, - user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), - graph_runtime_state=graph_runtime_state, + single_iteration_run=self.application_generate_entity.single_iteration_run, + single_loop_run=self.application_generate_entity.single_loop_run, ) else: inputs = self.application_generate_entity.inputs diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 71588870fa..e021b0aca7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -551,7 +551,7 @@ class AdvancedChatAppGenerateTaskPipeline: total_steps=validated_state.node_run_steps, outputs=event.outputs, exceptions_count=event.exceptions_count, - conversation_id=None, + conversation_id=self._conversation_id, trace_manager=trace_manager, external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), ) diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 9ce841f432..801619ddbc 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): raise ValueError("enabled in agent_mode must be of boolean type") if not agent_mode.get("strategy"): - agent_mode["strategy"] = PlanningStrategy.ROUTER.value + agent_mode["strategy"] = PlanningStrategy.ROUTER if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: raise ValueError("strategy in agent_mode must be in the specified strategy list") diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 388bed5255..759398b556 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner): # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: # check LLM mode - if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT: runner_cls = CotChatAgentRunner - elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value: + elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION: runner_cls = CotCompletionAgentRunner else: raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}") diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index fdba952eeb..4b246a53d3 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,9 +1,11 @@ +import logging import queue import time from abc import abstractmethod from enum import IntEnum, auto from typing import Any +from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta from configs import dify_config @@ -18,6 +20,8 @@ from core.app.entities.queue_entities import ( ) from extensions.ext_redis import redis_client +logger = logging.getLogger(__name__) + class PublishFrom(IntEnum): APPLICATION_MANAGER = auto() @@ -35,9 +39,8 @@ class AppQueueManager: self.invoke_from = invoke_from # Public accessor for invoke_from user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" - redis_client.setex( - AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" - ) + self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id) + redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}") q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() @@ -79,9 +82,21 @@ class AppQueueManager: Stop listen to queue :return: """ + self._clear_task_belong_cache() self._q.put(None) - def publish_error(self, e, pub_from: PublishFrom): + def _clear_task_belong_cache(self) -> None: + """ + Remove the task belong cache key once listening is finished. + """ + try: + redis_client.delete(self._task_belong_cache_key) + except RedisError: + logger.exception( + "Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key + ) + + def publish_error(self, e, pub_from: PublishFrom) -> None: """ Publish error :param e: error diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index e7db3bc41b..61ac040c05 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -61,9 +61,6 @@ class AppRunner: if model_context_tokens is None: return -1 - if max_tokens is None: - max_tokens = 0 - prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) if prompt_tokens + max_tokens > model_context_tokens: diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 76627b876b..bd077c4cb8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator): invoke_from=InvokeFrom.DEBUGGER, call_depth=0, workflow_execution_id=str(uuid.uuid4()), + single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity( + node_id=node_id, inputs=args["inputs"] + ), ) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, + context=contextvars.copy_context(), ) def single_loop_generate( @@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, + context=contextvars.copy_context(), ) def _generate_worker( diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index ebb8b15163..a8a7dde2b4 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner): db.session.close() # if only single iteration run is requested - if self.application_generate_entity.single_iteration_run: - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - # if only single iteration run is requested - graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + # Handle single iteration or single loop run + graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=workflow, - node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs, - graph_runtime_state=graph_runtime_state, - ) - elif self.application_generate_entity.single_loop_run: - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - # if only single loop run is requested - graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=workflow, - node_id=self.application_generate_entity.single_loop_run.node_id, - user_inputs=self.application_generate_entity.single_loop_run.inputs, - graph_runtime_state=graph_runtime_state, + single_iteration_run=self.application_generate_entity.single_iteration_run, + single_loop_run=self.application_generate_entity.single_loop_run, ) else: inputs = self.application_generate_entity.inputs @@ -133,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner): rag_pipeline_variables = [] if workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables: - rag_pipeline_variable = RAGPipelineVariable(**v) + rag_pipeline_variable = RAGPipelineVariable.model_validate(v) if ( rag_pipeline_variable.belong_to_node_id in (self.application_generate_entity.start_node_id, "shared") @@ -246,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner): workflow_id=workflow.id, graph_config=graph_config, user_id=self.application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index b009dc7715..943ae8ab4e 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) - # if only single iteration run is requested - if self.application_generate_entity.single_iteration_run: - # if only single iteration run is requested - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + # if only single iteration or single loop run is requested + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, - node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs, - graph_runtime_state=graph_runtime_state, - ) - elif self.application_generate_entity.single_loop_run: - # if only single loop run is requested - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=self._workflow, - node_id=self.application_generate_entity.single_loop_run.node_id, - user_inputs=self.application_generate_entity.single_loop_run.inputs, - graph_runtime_state=graph_runtime_state, + single_iteration_run=self.application_generate_entity.single_iteration_run, + single_loop_run=self.application_generate_entity.single_loop_run, ) else: inputs = self.application_generate_entity.inputs diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 056e03fa14..68eb455d26 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,3 +1,4 @@ +import time from collections.abc import Mapping from typing import Any, cast @@ -99,8 +100,8 @@ class WorkflowBasedAppRunner: workflow_id=workflow_id, graph_config=graph_config, user_id=user_id, - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) @@ -119,15 +120,81 @@ class WorkflowBasedAppRunner: return graph - def _get_graph_and_variable_pool_of_single_iteration( + def _prepare_single_node_execution( + self, + workflow: Workflow, + single_iteration_run: Any | None = None, + single_loop_run: Any | None = None, + ) -> tuple[Graph, VariablePool, GraphRuntimeState]: + """ + Prepare graph, variable pool, and runtime state for single node execution + (either single iteration or single loop). + + Args: + workflow: The workflow instance + single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise + single_loop_run: SingleLoopRunEntity if running single loop, None otherwise + + Returns: + A tuple containing (graph, variable_pool, graph_runtime_state) + + Raises: + ValueError: If neither single_iteration_run nor single_loop_run is specified + """ + # Create initial runtime state with variable pool containing environment variables + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + environment_variables=workflow.environment_variables, + ), + start_at=time.time(), + ) + + # Determine which type of single node execution and get graph/variable_pool + if single_iteration_run: + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=single_iteration_run.node_id, + user_inputs=dict(single_iteration_run.inputs), + graph_runtime_state=graph_runtime_state, + ) + elif single_loop_run: + graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( + workflow=workflow, + node_id=single_loop_run.node_id, + user_inputs=dict(single_loop_run.inputs), + graph_runtime_state=graph_runtime_state, + ) + else: + raise ValueError("Neither single_iteration_run nor single_loop_run is specified") + + # Return the graph, variable_pool, and the same graph_runtime_state used during graph creation + # This ensures all nodes in the graph reference the same GraphRuntimeState instance + return graph, variable_pool, graph_runtime_state + + def _get_graph_and_variable_pool_for_single_node_run( self, workflow: Workflow, node_id: str, - user_inputs: dict, + user_inputs: dict[str, Any], graph_runtime_state: GraphRuntimeState, + node_type_filter_key: str, # 'iteration_id' or 'loop_id' + node_type_label: str = "node", # 'iteration' or 'loop' for error messages ) -> tuple[Graph, VariablePool]: """ - Get variable pool of single iteration + Get graph and variable pool for single node execution (iteration or loop). + + Args: + workflow: The workflow instance + node_id: The node ID to execute + user_inputs: User inputs for the node + graph_runtime_state: The graph runtime state + node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id') + node_type_label: Label for error messages ('iteration' or 'loop') + + Returns: + A tuple containing (graph, variable_pool) """ # fetch workflow graph graph_config = workflow.graph_dict @@ -145,18 +212,22 @@ class WorkflowBasedAppRunner: if not isinstance(graph_config.get("edges"), list): raise ValueError("edges in workflow graph must be a list") - # filter nodes only in iteration + # filter nodes only in the specified node type (iteration or loop) + main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None) + start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None node_configs = [ node for node in graph_config.get("nodes", []) - if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id + if node.get("id") == node_id + or node.get("data", {}).get(node_type_filter_key, "") == node_id + or (start_node_id and node.get("id") == start_node_id) ] graph_config["nodes"] = node_configs node_ids = [node.get("id") for node in node_configs] - # filter edges only in iteration + # filter edges only in the specified node type edge_configs = [ edge for edge in graph_config.get("edges", []) @@ -173,8 +244,8 @@ class WorkflowBasedAppRunner: workflow_id=workflow.id, graph_config=graph_config, user_id="", - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) @@ -190,30 +261,26 @@ class WorkflowBasedAppRunner: raise ValueError("graph not found in workflow") # fetch node config from node id - iteration_node_config = None + target_node_config = None for node in node_configs: if node.get("id") == node_id: - iteration_node_config = node + target_node_config = node break - if not iteration_node_config: - raise ValueError("iteration node id not found in workflow graph") + if not target_node_config: + raise ValueError(f"{node_type_label} node id not found in workflow graph") # Get node class - node_type = NodeType(iteration_node_config.get("data", {}).get("type")) - node_version = iteration_node_config.get("data", {}).get("version", "1") + node_type = NodeType(target_node_config.get("data", {}).get("type")) + node_version = target_node_config.get("data", {}).get("version", "1") node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] - # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.empty(), - user_inputs={}, - environment_variables=workflow.environment_variables, - ) + # Use the variable pool from graph_runtime_state instead of creating a new one + variable_pool = graph_runtime_state.variable_pool try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, config=iteration_node_config + graph_config=workflow.graph_dict, config=target_node_config ) except NotImplementedError: variable_mapping = {} @@ -234,120 +301,44 @@ class WorkflowBasedAppRunner: return graph, variable_pool + def _get_graph_and_variable_pool_of_single_iteration( + self, + workflow: Workflow, + node_id: str, + user_inputs: dict[str, Any], + graph_runtime_state: GraphRuntimeState, + ) -> tuple[Graph, VariablePool]: + """ + Get variable pool of single iteration + """ + return self._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id=node_id, + user_inputs=user_inputs, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="iteration_id", + node_type_label="iteration", + ) + def _get_graph_and_variable_pool_of_single_loop( self, workflow: Workflow, node_id: str, - user_inputs: dict, + user_inputs: dict[str, Any], graph_runtime_state: GraphRuntimeState, ) -> tuple[Graph, VariablePool]: """ Get variable pool of single loop """ - # fetch workflow graph - graph_config = workflow.graph_dict - if not graph_config: - raise ValueError("workflow graph not found") - - graph_config = cast(dict[str, Any], graph_config) - - if "nodes" not in graph_config or "edges" not in graph_config: - raise ValueError("nodes or edges not found in workflow graph") - - if not isinstance(graph_config.get("nodes"), list): - raise ValueError("nodes in workflow graph must be a list") - - if not isinstance(graph_config.get("edges"), list): - raise ValueError("edges in workflow graph must be a list") - - # filter nodes only in loop - node_configs = [ - node - for node in graph_config.get("nodes", []) - if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id - ] - - graph_config["nodes"] = node_configs - - node_ids = [node.get("id") for node in node_configs] - - # filter edges only in loop - edge_configs = [ - edge - for edge in graph_config.get("edges", []) - if (edge.get("source") is None or edge.get("source") in node_ids) - and (edge.get("target") is None or edge.get("target") in node_ids) - ] - - graph_config["edges"] = edge_configs - - # Create required parameters for Graph.init - graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, - workflow_id=workflow.id, - graph_config=graph_config, - user_id="", - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, - call_depth=0, - ) - - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + return self._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id=node_id, + user_inputs=user_inputs, graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", ) - # init graph - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id) - - if not graph: - raise ValueError("graph not found in workflow") - - # fetch node config from node id - loop_node_config = None - for node in node_configs: - if node.get("id") == node_id: - loop_node_config = node - break - - if not loop_node_config: - raise ValueError("loop node id not found in workflow graph") - - # Get node class - node_type = NodeType(loop_node_config.get("data", {}).get("type")) - node_version = loop_node_config.get("data", {}).get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] - - # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.empty(), - user_inputs={}, - environment_variables=workflow.environment_variables, - ) - - try: - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, config=loop_node_config - ) - except NotImplementedError: - variable_mapping = {} - load_into_variable_pool( - self._variable_loader, - variable_pool=variable_pool, - variable_mapping=variable_mapping, - user_inputs=user_inputs, - ) - - WorkflowEntry.mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - ) - - return graph, variable_pool - def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): """ Handle event diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 0004fb592e..7a384e5c92 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -107,7 +107,6 @@ class MessageCycleManager: if dify_config.DEBUG: logger.exception("generate conversation name failed, conversation_id: %s", conversation_id) - db.session.merge(conversation) db.session.commit() db.session.close() diff --git a/api/core/datasource/__base/datasource_runtime.py b/api/core/datasource/__base/datasource_runtime.py index b7f280208a..c5d6c1d771 100644 --- a/api/core/datasource/__base/datasource_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Any, Optional -from openai import BaseModel -from pydantic import Field +from pydantic import BaseModel, Field # Import InvokeFrom locally to avoid circular import from core.app.entities.app_invoke_entities import InvokeFrom diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index cdefcc4506..1179537570 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel): for datasource in datasources: if datasource.get("parameters"): for parameter in datasource.get("parameters"): - if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value: + if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES: parameter["type"] = "files" # ------------- diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py index ac36d83ae3..3c64632dbb 100644 --- a/api/core/datasource/entities/common_entities.py +++ b/api/core/datasource/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class I18nObject(BaseModel): @@ -11,11 +11,12 @@ class I18nObject(BaseModel): pt_BR: str | None = Field(default=None) ja_JP: str | None = Field(default=None) - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): self.zh_Hans = self.zh_Hans or self.en_US self.pt_BR = self.pt_BR or self.en_US self.ja_JP = self.ja_JP or self.en_US + return self def to_dict(self) -> dict: return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index ac4f51ac75..7f503b963f 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -54,16 +54,16 @@ class DatasourceParameter(PluginParameter): removes TOOLS_SELECTOR from PluginParameterType """ - STRING = PluginParameterType.STRING.value - NUMBER = PluginParameterType.NUMBER.value - BOOLEAN = PluginParameterType.BOOLEAN.value - SELECT = PluginParameterType.SELECT.value - SECRET_INPUT = PluginParameterType.SECRET_INPUT.value - FILE = PluginParameterType.FILE.value - FILES = PluginParameterType.FILES.value + STRING = PluginParameterType.STRING + NUMBER = PluginParameterType.NUMBER + BOOLEAN = PluginParameterType.BOOLEAN + SELECT = PluginParameterType.SELECT + SECRET_INPUT = PluginParameterType.SECRET_INPUT + FILE = PluginParameterType.FILE + FILES = PluginParameterType.FILES # deprecated, should not use. - SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value + SYSTEM_FILES = PluginParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index de3b0964ff..bc19afb52a 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] for configurate_method in self.provider.configurate_methods: @@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel): and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) + return self def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ @@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel): """ Get custom provider record. """ - # get provider - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - stmt = select(Provider).where( Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, - Provider.provider_name.in_(provider_names), + Provider.provider_type == ProviderType.CUSTOM, + Provider.provider_name.in_(self._get_provider_names()), ) return session.execute(stmt).scalar_one_or_none() @@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel): """ stmt = select(ProviderCredential.id).where( ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.credential_name == credential_name, ) if exclude_id: @@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel): try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) credential_record = s.execute(stmt).scalar_one_or_none() @@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel): session=session, query_factory=lambda: select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ), ) @@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel): session=session, query_factory=lambda: select(ProviderModelCredential).where( ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ), @@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel): logger.warning("Error generating next credential name: %s", str(e)) return "API KEY 1" + def _get_provider_names(self): + """ + The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`. + """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + return provider_names + def create_provider_credential(self, credentials: dict, credential_name: str | None): """ Add custom provider credentials. @@ -454,7 +458,7 @@ class ProviderConfiguration(BaseModel): provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - provider_type=ProviderType.CUSTOM.value, + provider_type=ProviderType.CUSTOM, is_valid=True, credential_id=new_record.id, ) @@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ) # Get the credential record to update @@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel): # Find all load balancing configs that use this credential_id stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == credential_source, ) @@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ) # Get the credential record to update @@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel): # Check if this credential is used in load balancing configs lb_stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == "provider", ) @@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel): # if this is the last credential, we need to delete the provider record count_stmt = select(func.count(ProviderCredential.id)).where( ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ) available_credentials_count = session.execute(count_stmt).scalar() or 0 session.delete(credential_record) @@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: @@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel): """ stmt = select(ProviderModelCredential).where( ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.credential_name == credential_name, @@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel): lb_stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == "custom_model", ) @@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel): # if this is the last credential, we need to delete the custom model record count_stmt = select(func.count(ProviderModelCredential.id)).where( ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel): """ Get provider model setting. """ - - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - stmt = select(ProviderModelSetting).where( ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name.in_(provider_names), + ProviderModelSetting.provider_name.in_(self._get_provider_names()), ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_name == model, ) @@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel): return def _switch(s: Session): - # get preferred provider - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - stmt = select(TenantPreferredModelProvider).where( TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name.in_(provider_names), + TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()), ) preferred_model_provider = s.execute(stmt).scalars().first() @@ -1422,7 +1414,7 @@ class ProviderConfiguration(BaseModel): """ secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: - if credential_form_schema.type.value == FormType.SECRET_INPUT.value: + if credential_form_schema.type.value == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.variable) return secret_input_form_variables diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index fab9ae44e9..f9e6099049 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,13 +1,13 @@ from typing import cast -import requests +import httpx from configs import dify_config from models.api_based_extension import APIBasedExtensionPoint class APIBasedExtensionRequestor: - timeout: tuple[int, int] = (5, 60) + timeout: httpx.Timeout = httpx.Timeout(60.0, connect=5.0) """timeout for request connect and read""" def __init__(self, api_endpoint: str, api_key: str): @@ -27,25 +27,23 @@ class APIBasedExtensionRequestor: url = self.api_endpoint try: - # proxy support for security - proxies = None + mounts: dict[str, httpx.BaseTransport] | None = None if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: - proxies = { - "http": dify_config.SSRF_PROXY_HTTP_URL, - "https": dify_config.SSRF_PROXY_HTTPS_URL, + mounts = { + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL), + "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL), } - response = requests.request( - method="POST", - url=url, - json={"point": point.value, "params": params}, - headers=headers, - timeout=self.timeout, - proxies=proxies, - ) - except requests.Timeout: + with httpx.Client(mounts=mounts, timeout=self.timeout) as client: + response = client.request( + method="POST", + url=url, + json={"point": point.value, "params": params}, + headers=headers, + ) + except httpx.TimeoutException: raise ValueError("request timeout") - except requests.ConnectionError: + except httpx.RequestError: raise ValueError("request connection error") if response.status_code != 200: diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index c44a8e1840..f92278f9e2 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,7 +4,7 @@ from enum import StrEnum from threading import Lock from typing import Any -from httpx import Timeout, post +import httpx from pydantic import BaseModel from yarl import URL @@ -13,9 +13,17 @@ from core.helper.code_executor.javascript.javascript_transformer import NodeJsTe from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer +from core.helper.http_client_pooling import get_pooled_http_client logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) +CODE_EXECUTION_SSL_VERIFY = dify_config.CODE_EXECUTION_SSL_VERIFY +_CODE_EXECUTOR_CLIENT_LIMITS = httpx.Limits( + max_connections=dify_config.CODE_EXECUTION_POOL_MAX_CONNECTIONS, + max_keepalive_connections=dify_config.CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS, + keepalive_expiry=dify_config.CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY, +) +_CODE_EXECUTOR_CLIENT_KEY = "code_executor:http_client" class CodeExecutionError(Exception): @@ -38,6 +46,13 @@ class CodeLanguage(StrEnum): JAVASCRIPT = "javascript" +def _build_code_executor_client() -> httpx.Client: + return httpx.Client( + verify=CODE_EXECUTION_SSL_VERIFY, + limits=_CODE_EXECUTOR_CLIENT_LIMITS, + ) + + class CodeExecutor: dependencies_cache: dict[str, str] = {} dependencies_cache_lock = Lock() @@ -76,17 +91,21 @@ class CodeExecutor: "enable_network": True, } + timeout = httpx.Timeout( + connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, + read=dify_config.CODE_EXECUTION_READ_TIMEOUT, + write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, + pool=None, + ) + + client = get_pooled_http_client(_CODE_EXECUTOR_CLIENT_KEY, _build_code_executor_client) + try: - response = post( + response = client.post( str(url), json=data, headers=headers, - timeout=Timeout( - connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, - read=dify_config.CODE_EXECUTION_READ_TIMEOUT, - write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, - pool=None, - ), + timeout=timeout, ) if response.status_code == 503: raise CodeExecutionError("Code execution service is unavailable") @@ -106,13 +125,13 @@ class CodeExecutor: try: response_data = response.json() - except: - raise CodeExecutionError("Failed to parse response") + except Exception as e: + raise CodeExecutionError("Failed to parse response") from e if (code := response_data.get("code")) != 0: raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}") - response_code = CodeExecutionResponse(**response_data) + response_code = CodeExecutionResponse.model_validate(response_data) if response_code.data.error: raise CodeExecutionError(response_code.data.error) diff --git a/api/core/helper/http_client_pooling.py b/api/core/helper/http_client_pooling.py new file mode 100644 index 0000000000..f4c3ff0e8b --- /dev/null +++ b/api/core/helper/http_client_pooling.py @@ -0,0 +1,59 @@ +"""HTTP client pooling utilities.""" + +from __future__ import annotations + +import atexit +import threading +from collections.abc import Callable + +import httpx + +ClientBuilder = Callable[[], httpx.Client] + + +class HttpClientPoolFactory: + """Thread-safe factory that maintains reusable HTTP client instances.""" + + def __init__(self) -> None: + self._clients: dict[str, httpx.Client] = {} + self._lock = threading.Lock() + + def get_or_create(self, key: str, builder: ClientBuilder) -> httpx.Client: + """Return a pooled client associated with ``key`` creating it on demand.""" + client = self._clients.get(key) + if client is not None: + return client + + with self._lock: + client = self._clients.get(key) + if client is None: + client = builder() + self._clients[key] = client + return client + + def close_all(self) -> None: + """Close all pooled clients and clear the pool.""" + with self._lock: + for client in self._clients.values(): + client.close() + self._clients.clear() + + +_factory = HttpClientPoolFactory() + + +def get_pooled_http_client(key: str, builder: ClientBuilder) -> httpx.Client: + """Return a pooled client for the given ``key`` using ``builder`` when missing.""" + return _factory.get_or_create(key, builder) + + +def close_all_pooled_clients() -> None: + """Close every client created through the pooling factory.""" + _factory.close_all() + + +def _register_shutdown_hook() -> None: + atexit.register(close_all_pooled_clients) + + +_register_shutdown_hook() diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index 89dae4808f..bddb864a95 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -23,10 +23,10 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP return [] url = str(marketplace_api_url / "api/v1/plugins/batch") - response = httpx.post(url, json={"plugin_ids": plugin_ids}) + response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version}) response.raise_for_status() - return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] + return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]] def batch_fetch_plugin_manifests_ignore_deserialization_error( @@ -36,12 +36,12 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error( return [] url = str(marketplace_api_url / "api/v1/plugins/batch") - response = httpx.post(url, json={"plugin_ids": plugin_ids}) + response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version}) response.raise_for_status() result: list[MarketplacePluginDeclaration] = [] for plugin in response.json()["data"]["plugins"]: try: - result.append(MarketplacePluginDeclaration(**plugin)) + result.append(MarketplacePluginDeclaration.model_validate(plugin)) except Exception: pass diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index cbb78939d2..0de026f3c7 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -8,27 +8,23 @@ import time import httpx from configs import dify_config +from core.helper.http_client_pooling import get_pooled_http_client logger = logging.getLogger(__name__) SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES -http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True -try: - config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY - http_request_node_ssl_verify_lower = str(config_value).lower() - if http_request_node_ssl_verify_lower == "true": - http_request_node_ssl_verify = True - elif http_request_node_ssl_verify_lower == "false": - http_request_node_ssl_verify = False - else: - raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") -except NameError: - http_request_node_ssl_verify = True - BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] +_SSL_VERIFIED_POOL_KEY = "ssrf:verified" +_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified" +_SSRF_CLIENT_LIMITS = httpx.Limits( + max_connections=dify_config.SSRF_POOL_MAX_CONNECTIONS, + max_keepalive_connections=dify_config.SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS, + keepalive_expiry=dify_config.SSRF_POOL_KEEPALIVE_EXPIRY, +) + class MaxRetriesExceededError(ValueError): """Raised when the maximum number of retries is exceeded.""" @@ -36,6 +32,45 @@ class MaxRetriesExceededError(ValueError): pass +def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]: + return { + "http://": httpx.HTTPTransport( + proxy=dify_config.SSRF_PROXY_HTTP_URL, + ), + "https://": httpx.HTTPTransport( + proxy=dify_config.SSRF_PROXY_HTTPS_URL, + ), + } + + +def _build_ssrf_client(verify: bool) -> httpx.Client: + if dify_config.SSRF_PROXY_ALL_URL: + return httpx.Client( + proxy=dify_config.SSRF_PROXY_ALL_URL, + verify=verify, + limits=_SSRF_CLIENT_LIMITS, + ) + + if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: + return httpx.Client( + mounts=_create_proxy_mounts(), + verify=verify, + limits=_SSRF_CLIENT_LIMITS, + ) + + return httpx.Client(verify=verify, limits=_SSRF_CLIENT_LIMITS) + + +def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client: + if not isinstance(ssl_verify_enabled, bool): + raise ValueError("SSRF client verify flag must be a boolean") + + return get_pooled_http_client( + _SSL_VERIFIED_POOL_KEY if ssl_verify_enabled else _SSL_UNVERIFIED_POOL_KEY, + lambda: _build_ssrf_client(verify=ssl_verify_enabled), + ) + + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") @@ -50,33 +85,22 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, ) - if "ssl_verify" not in kwargs: - kwargs["ssl_verify"] = http_request_node_ssl_verify - - ssl_verify = kwargs.pop("ssl_verify") + # prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI + verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY) + client = _get_ssrf_client(verify_option) retries = 0 while retries <= max_retries: try: - if dify_config.SSRF_PROXY_ALL_URL: - with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client: - response = client.request(method=method, url=url, **kwargs) - elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: - proxy_mounts = { - "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify), - "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify), - } - with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client: - response = client.request(method=method, url=url, **kwargs) - else: - with httpx.Client(verify=ssl_verify) as client: - response = client.request(method=method, url=url, **kwargs) + response = client.request(method=method, url=url, **kwargs) if response.status_code not in STATUS_FORCELIST: return response else: logger.warning( - "Received status code %s for URL %s which is in the force list", response.status_code, url + "Received status code %s for URL %s which is in the force list", + response.status_code, + url, ) except httpx.RequestError as e: diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index ee37024260..7822ed4268 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -343,7 +343,7 @@ class IndexingRunner: if file_detail: extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=dataset_document.doc_form, ) @@ -356,15 +356,17 @@ class IndexingRunner: ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": data_source_info["credential_id"], - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "document": dataset_document, - "tenant_id": dataset_document.tenant_id, - }, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info["credential_id"], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id, + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) @@ -377,15 +379,17 @@ class IndexingRunner: ): raise ValueError("no website import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "tenant_id": dataset_document.tenant_id, - "url": data_source_info["url"], - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - }, + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index e07d0ec14e..e64ac25ab1 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -28,7 +28,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from core.workflow.node_events import AgentLogEvent from extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel @@ -462,19 +461,18 @@ class LLMGenerator: ) def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence: - raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG) + raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG, []) if not raw_agent_log: return [] - parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) - def dict_of_event(event: AgentLogEvent): - return { - "status": event.status, - "error": event.error, - "data": event.data, + return [ + { + "status": event["status"], + "error": event["error"], + "data": event["data"], } - - return [dict_of_event(event) for event in parsed] + for event in raw_agent_log + ] inputs = last_run.load_full_inputs(session, storage) last_run_dict = { diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 1e302b7668..686529c3ca 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -224,8 +224,8 @@ def _handle_native_json_schema( # Set appropriate response format if required by the model for rule in rules: - if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: - model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value + if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA return model_parameters @@ -239,10 +239,10 @@ def _set_response_format(model_parameters: dict, rules: list): """ for rule in rules: if rule.name == "response_format": - if ResponseFormat.JSON.value in rule.options: - model_parameters["response_format"] = ResponseFormat.JSON.value - elif ResponseFormat.JSON_OBJECT.value in rule.options: - model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value + if ResponseFormat.JSON in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON + elif ResponseFormat.JSON_OBJECT in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_OBJECT def _handle_prompt_based_schema( diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 5817416ba4..fa1d309134 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -294,7 +294,7 @@ class ClientSession( method="completion/complete", params=types.CompleteRequestParams( ref=ref, - argument=types.CompletionArgument(**argument), + argument=types.CompletionArgument.model_validate(argument), ), ) ), diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index c7353de5af..b673efae22 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, model_validator class I18nObject(BaseModel): @@ -9,7 +9,8 @@ class I18nObject(BaseModel): zh_Hans: str | None = None en_US: str - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): if not self.zh_Hans: self.zh_Hans = self.en_US + return self diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 9235c881e0..89dae2dbff 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -74,7 +74,7 @@ class TextPromptMessageContent(PromptMessageContent): Model class for text prompt message content. """ - type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT + type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore data: str @@ -95,11 +95,11 @@ class MultiModalPromptMessageContent(PromptMessageContent): class VideoPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO + type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore class AudioPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO + type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore class ImagePromptMessageContent(MultiModalPromptMessageContent): @@ -111,12 +111,12 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent): LOW = auto() HIGH = auto() - type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE + type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore detail: DETAIL = DETAIL.LOW class DocumentPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT + type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore PromptMessageContentUnionTypes = Annotated[ diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index 2ccc9e0eae..831fb9d4db 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from enum import Enum, StrEnum, auto -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, ModelType @@ -46,10 +46,11 @@ class FormOption(BaseModel): value: str show_on: list[FormShowOnObject] = [] - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): if not self.label: self.label = I18nObject(en_US=self.value) + return self class CredentialFormSchema(BaseModel): diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py index 23d36c03af..3967acf07b 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py @@ -15,7 +15,7 @@ class GPT2Tokenizer: use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() - tokens = _tokenizer.encode(text) + tokens = _tokenizer.encode(text) # type: ignore return len(tokens) @staticmethod diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index e070c17abd..e1afc41bee 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -269,17 +269,17 @@ class ModelProviderFactory: } if model_type == ModelType.LLM: - return LargeLanguageModel(**init_params) # type: ignore + return LargeLanguageModel.model_validate(init_params) elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel(**init_params) # type: ignore + return TextEmbeddingModel.model_validate(init_params) elif model_type == ModelType.RERANK: - return RerankModel(**init_params) # type: ignore + return RerankModel.model_validate(init_params) elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel(**init_params) # type: ignore + return Speech2TextModel.model_validate(init_params) elif model_type == ModelType.MODERATION: - return ModerationModel(**init_params) # type: ignore + return ModerationModel.model_validate(init_params) elif model_type == ModelType.TTS: - return TTSModel(**init_params) # type: ignore + return TTSModel.model_validate(init_params) def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: """ diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index c758eaf49f..c85152463e 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -196,15 +196,15 @@ def jsonable_encoder( return encoder(obj) try: - data = dict(obj) + data = dict(obj) # type: ignore except Exception as e: errors: list[Exception] = [] errors.append(e) try: - data = vars(obj) + data = vars(obj) # type: ignore except Exception as e: errors.append(e) - raise ValueError(errors) from e + raise ValueError(str(errors)) from e return jsonable_encoder( data, by_alias=by_alias, diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 573f4ec2a7..2d72b17a04 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -51,7 +51,7 @@ class ApiModeration(Moderation): params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) - return ModerationInputsResult(**result) + return ModerationInputsResult.model_validate(result) return ModerationInputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response @@ -67,7 +67,7 @@ class ApiModeration(Moderation): params = ModerationOutputParams(app_id=self.app_id, text=text) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) - return ModerationOutputsResult(**result) + return ModerationOutputsResult.model_validate(result) return ModerationOutputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 7e817a6bff..a7d8576d8d 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,38 +1,28 @@ -import json import logging from collections.abc import Sequence -from urllib.parse import urljoin -from opentelemetry.trace import Link, Status, StatusCode -from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker from core.ops.aliyun_trace.data_exporter.traceclient import ( TraceClient, + build_endpoint, convert_datetime_to_nanoseconds, convert_to_span_id, convert_to_trace_id, - create_link, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata from core.ops.aliyun_trace.entities.semconv import ( GEN_AI_COMPLETION, - GEN_AI_FRAMEWORK, - GEN_AI_MODEL_NAME, + GEN_AI_INPUT_MESSAGE, + GEN_AI_OUTPUT_MESSAGE, GEN_AI_PROMPT, - GEN_AI_PROMPT_TEMPLATE_TEMPLATE, - GEN_AI_PROMPT_TEMPLATE_VARIABLE, + GEN_AI_PROVIDER_NAME, + GEN_AI_REQUEST_MODEL, GEN_AI_RESPONSE_FINISH_REASON, - GEN_AI_SESSION_ID, - GEN_AI_SPAN_KIND, - GEN_AI_SYSTEM, GEN_AI_USAGE_INPUT_TOKENS, GEN_AI_USAGE_OUTPUT_TOKENS, GEN_AI_USAGE_TOTAL_TOKENS, - GEN_AI_USER_ID, - INPUT_VALUE, - OUTPUT_VALUE, RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, TOOL_DESCRIPTION, @@ -40,6 +30,18 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) +from core.ops.aliyun_trace.utils import ( + create_common_span_attributes, + create_links_from_trace_id, + create_status_from_error, + extract_retrieval_documents, + format_input_messages, + format_output_messages, + format_retrieval_documents, + get_user_id_from_message_data, + get_workflow_node_status, + serialize_json_data, +) from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import AliyunConfig from core.ops.entities.trace_entity import ( @@ -52,12 +54,11 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.rag.models.document import Document from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.entities import WorkflowNodeExecution -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db -from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom +from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -68,8 +69,7 @@ class AliyunDataTrace(BaseTraceInstance): aliyun_config: AliyunConfig, ): super().__init__(aliyun_config) - base_url = aliyun_config.endpoint.rstrip("/") - endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces") + endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key) self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint) def trace(self, trace_info: BaseTraceInfo): @@ -95,423 +95,425 @@ class AliyunDataTrace(BaseTraceInstance): try: return self.trace_client.get_project_url() except Exception as e: - logger.info("Aliyun get run url failed: %s", str(e), exc_info=True) - raise ValueError(f"Aliyun get run url failed: {str(e)}") + logger.info("Aliyun get project url failed: %s", str(e), exc_info=True) + raise ValueError(f"Aliyun get project url failed: {str(e)}") def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = convert_to_trace_id(trace_info.workflow_run_id) - links = [] - if trace_info.trace_id: - links.append(create_link(trace_id_str=trace_info.trace_id)) - workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow") - self.add_workflow_span(trace_id, workflow_span_id, trace_info, links) + trace_metadata = TraceMetadata( + trace_id=convert_to_trace_id(trace_info.workflow_run_id), + workflow_span_id=convert_to_span_id(trace_info.workflow_run_id, "workflow"), + session_id=trace_info.metadata.get("conversation_id") or "", + user_id=str(trace_info.metadata.get("user_id") or ""), + links=create_links_from_trace_id(trace_info.trace_id), + ) + + self.add_workflow_span(trace_info, trace_metadata) workflow_node_executions = self.get_workflow_node_executions(trace_info) for node_execution in workflow_node_executions: - node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id) + node_span = self.build_workflow_node_span(node_execution, trace_info, trace_metadata) self.trace_client.add_span(node_span) def message_trace(self, trace_info: MessageTraceInfo): message_data = trace_info.message_data if message_data is None: return + message_id = trace_info.message_id + user_id = get_user_id_from_message_data(message_data) + status = create_status_from_error(trace_info.error) - user_id = message_data.from_account_id - if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) - if end_user_data is not None: - user_id = end_user_data.session_id + trace_metadata = TraceMetadata( + trace_id=convert_to_trace_id(message_id), + workflow_span_id=0, + session_id=trace_info.metadata.get("conversation_id") or "", + user_id=user_id, + links=create_links_from_trace_id(trace_info.trace_id), + ) - status: Status = Status(StatusCode.OK) - if trace_info.error: - status = Status(StatusCode.ERROR, trace_info.error) - - trace_id = convert_to_trace_id(message_id) - links = [] - if trace_info.trace_id: - links.append(create_link(trace_id_str=trace_info.trace_id)) + inputs_json = serialize_json_data(trace_info.inputs) + outputs_str = str(trace_info.outputs) message_span_id = convert_to_span_id(message_id, "message") message_span = SpanData( - trace_id=trace_id, + trace_id=trace_metadata.trace_id, parent_span_id=None, span_id=message_span_id, name="message", start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), - attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", - GEN_AI_USER_ID: str(user_id), - GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, - GEN_AI_FRAMEWORK: "dify", - INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - OUTPUT_VALUE: str(trace_info.outputs), - }, + attributes=create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.CHAIN, + inputs=inputs_json, + outputs=outputs_str, + ), status=status, - links=links, + links=trace_metadata.links, ) self.trace_client.add_span(message_span) - app_model_config = getattr(trace_info.message_data, "app_model_config", {}) - pre_prompt = getattr(app_model_config, "pre_prompt", "") - inputs_data = getattr(trace_info.message_data, "inputs", {}) llm_span = SpanData( - trace_id=trace_id, + trace_id=trace_metadata.trace_id, parent_span_id=message_span_id, span_id=convert_to_span_id(message_id, "llm"), name="llm", start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", - GEN_AI_USER_ID: str(user_id), - GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, - GEN_AI_FRAMEWORK: "dify", - GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "", - GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "", + **create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.LLM, + inputs=inputs_json, + outputs=outputs_str, + ), + GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "", + GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "", GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens), GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens), GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens), - GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False), - GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt, - GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False), - GEN_AI_COMPLETION: str(trace_info.outputs), - INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - OUTPUT_VALUE: str(trace_info.outputs), + GEN_AI_PROMPT: inputs_json, + GEN_AI_COMPLETION: outputs_str, }, status=status, + links=trace_metadata.links, ) self.trace_client.add_span(llm_span) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): if trace_info.message_data is None: return + message_id = trace_info.message_id - trace_id = convert_to_trace_id(message_id) - links = [] - if trace_info.trace_id: - links.append(create_link(trace_id_str=trace_info.trace_id)) + trace_metadata = TraceMetadata( + trace_id=convert_to_trace_id(message_id), + workflow_span_id=0, + session_id=trace_info.metadata.get("conversation_id") or "", + user_id=str(trace_info.metadata.get("user_id") or ""), + links=create_links_from_trace_id(trace_info.trace_id), + ) documents_data = extract_retrieval_documents(trace_info.documents) + documents_json = serialize_json_data(documents_data) + inputs_str = str(trace_info.inputs) + dataset_retrieval_span = SpanData( - trace_id=trace_id, + trace_id=trace_metadata.trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=generate_span_id(), name="dataset_retrieval", start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value, - GEN_AI_FRAMEWORK: "dify", - RETRIEVAL_QUERY: str(trace_info.inputs), - RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False), - INPUT_VALUE: str(trace_info.inputs), - OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False), + **create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.RETRIEVER, + inputs=inputs_str, + outputs=documents_json, + ), + RETRIEVAL_QUERY: inputs_str, + RETRIEVAL_DOCUMENT: documents_json, }, - links=links, + links=trace_metadata.links, ) self.trace_client.add_span(dataset_retrieval_span) def tool_trace(self, trace_info: ToolTraceInfo): if trace_info.message_data is None: return + message_id = trace_info.message_id + status = create_status_from_error(trace_info.error) - status: Status = Status(StatusCode.OK) - if trace_info.error: - status = Status(StatusCode.ERROR, trace_info.error) + trace_metadata = TraceMetadata( + trace_id=convert_to_trace_id(message_id), + workflow_span_id=0, + session_id=trace_info.metadata.get("conversation_id") or "", + user_id=str(trace_info.metadata.get("user_id") or ""), + links=create_links_from_trace_id(trace_info.trace_id), + ) - trace_id = convert_to_trace_id(message_id) - links = [] - if trace_info.trace_id: - links.append(create_link(trace_id_str=trace_info.trace_id)) + tool_config_json = serialize_json_data(trace_info.tool_config) + tool_inputs_json = serialize_json_data(trace_info.tool_inputs) + inputs_json = serialize_json_data(trace_info.inputs) tool_span = SpanData( - trace_id=trace_id, + trace_id=trace_metadata.trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=generate_span_id(), name=trace_info.tool_name, start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value, - GEN_AI_FRAMEWORK: "dify", + **create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.TOOL, + inputs=inputs_json, + outputs=str(trace_info.tool_outputs), + ), TOOL_NAME: trace_info.tool_name, - TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False), - TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False), - INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - OUTPUT_VALUE: str(trace_info.tool_outputs), + TOOL_DESCRIPTION: tool_config_json, + TOOL_PARAMETERS: tool_inputs_json, }, status=status, - links=links, + links=trace_metadata.links, ) self.trace_client.add_span(tool_span) def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]: - # through workflow_run_id get all_nodes_execution using repository - session_factory = sessionmaker(bind=db.engine) - # Find the app's creator account - with Session(db.engine, expire_on_commit=False) as session: - # Get the app to find its creator - app_id = trace_info.metadata.get("app_id") - if not app_id: - raise ValueError("No app_id found in trace_info metadata") - app_stmt = select(App).where(App.id == app_id) - app = session.scalar(app_stmt) - if not app: - raise ValueError(f"App with id {app_id} not found") + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") - if not app.created_by: - raise ValueError(f"App with id {app_id} has no creator (created_by is None)") - account_stmt = select(Account).where(Account.id == app.created_by) - service_account = session.scalar(account_stmt) - if not service_account: - raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") - current_tenant = ( - session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() - ) - if not current_tenant: - raise ValueError(f"Current tenant not found for account {service_account.id}") - service_account.set_tenant_id(current_tenant.tenant_id) + service_account = self.get_service_account_with_tenant(app_id) + + session_factory = sessionmaker(bind=db.engine) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, user=service_account, app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id - ) - return workflow_node_executions + + return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) def build_workflow_node_span( - self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int + self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata ): try: if node_execution.node_type == NodeType.LLM: - node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution) + node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata) elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL: - node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution) + node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata) elif node_execution.node_type == NodeType.TOOL: - node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution) + node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata) else: - node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution) + node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata) return node_span except Exception as e: logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True) return None - def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status: - span_status: Status = Status(StatusCode.UNSET) - if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED: - span_status = Status(StatusCode.OK) - elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]: - span_status = Status(StatusCode.ERROR, str(node_execution.error)) - return span_status - def build_workflow_task_span( - self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution + self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata ) -> SpanData: + inputs_json = serialize_json_data(node_execution.inputs) + outputs_json = serialize_json_data(node_execution.outputs) return SpanData( - trace_id=trace_id, - parent_span_id=workflow_span_id, + trace_id=trace_metadata.trace_id, + parent_span_id=trace_metadata.workflow_span_id, span_id=convert_to_span_id(node_execution.id, "node"), name=node_execution.title, start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), - attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", - GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value, - GEN_AI_FRAMEWORK: "dify", - INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False), - OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False), - }, - status=self.get_workflow_node_status(node_execution), + attributes=create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.TASK, + inputs=inputs_json, + outputs=outputs_json, + ), + status=get_workflow_node_status(node_execution), + links=trace_metadata.links, ) def build_workflow_tool_span( - self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution + self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata ) -> SpanData: tool_des = {} if node_execution.metadata: tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {}) + + inputs_json = serialize_json_data(node_execution.inputs or {}) + outputs_json = serialize_json_data(node_execution.outputs) + return SpanData( - trace_id=trace_id, - parent_span_id=workflow_span_id, + trace_id=trace_metadata.trace_id, + parent_span_id=trace_metadata.workflow_span_id, span_id=convert_to_span_id(node_execution.id, "node"), name=node_execution.title, start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value, - GEN_AI_FRAMEWORK: "dify", + **create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.TOOL, + inputs=inputs_json, + outputs=outputs_json, + ), TOOL_NAME: node_execution.title, - TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False), - TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False), - INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False), - OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False), + TOOL_DESCRIPTION: serialize_json_data(tool_des), + TOOL_PARAMETERS: inputs_json, }, - status=self.get_workflow_node_status(node_execution), + status=get_workflow_node_status(node_execution), + links=trace_metadata.links, ) def build_workflow_retrieval_span( - self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution + self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata ) -> SpanData: - input_value = "" - if node_execution.inputs: - input_value = str(node_execution.inputs.get("query", "")) - output_value = "" - if node_execution.outputs: - output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False) + input_value = str(node_execution.inputs.get("query", "")) if node_execution.inputs else "" + output_value = serialize_json_data(node_execution.outputs.get("result", [])) if node_execution.outputs else "" + + retrieval_documents = node_execution.outputs.get("result", []) if node_execution.outputs else [] + semantic_retrieval_documents = format_retrieval_documents(retrieval_documents) + semantic_retrieval_documents_json = serialize_json_data(semantic_retrieval_documents) + return SpanData( - trace_id=trace_id, - parent_span_id=workflow_span_id, + trace_id=trace_metadata.trace_id, + parent_span_id=trace_metadata.workflow_span_id, span_id=convert_to_span_id(node_execution.id, "node"), name=node_execution.title, start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value, - GEN_AI_FRAMEWORK: "dify", + **create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.RETRIEVER, + inputs=input_value, + outputs=output_value, + ), RETRIEVAL_QUERY: input_value, - RETRIEVAL_DOCUMENT: output_value, - INPUT_VALUE: input_value, - OUTPUT_VALUE: output_value, + RETRIEVAL_DOCUMENT: semantic_retrieval_documents_json, }, - status=self.get_workflow_node_status(node_execution), + status=get_workflow_node_status(node_execution), + links=trace_metadata.links, ) def build_workflow_llm_span( - self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution + self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata ) -> SpanData: process_data = node_execution.process_data or {} outputs = node_execution.outputs or {} usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + + prompts_json = serialize_json_data(process_data.get("prompts", [])) + text_output = str(outputs.get("text", "")) + + gen_ai_input_message = format_input_messages(process_data) + gen_ai_output_message = format_output_messages(outputs) + return SpanData( - trace_id=trace_id, - parent_span_id=workflow_span_id, + trace_id=trace_metadata.trace_id, + parent_span_id=trace_metadata.workflow_span_id, span_id=convert_to_span_id(node_execution.id, "node"), name=node_execution.title, start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", - GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, - GEN_AI_FRAMEWORK: "dify", - GEN_AI_MODEL_NAME: process_data.get("model_name") or "", - GEN_AI_SYSTEM: process_data.get("model_provider") or "", + **create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.LLM, + inputs=prompts_json, + outputs=text_output, + ), + GEN_AI_REQUEST_MODEL: process_data.get("model_name") or "", + GEN_AI_PROVIDER_NAME: process_data.get("model_provider") or "", GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)), GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)), GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)), - GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False), - GEN_AI_COMPLETION: str(outputs.get("text", "")), + GEN_AI_PROMPT: prompts_json, + GEN_AI_COMPLETION: text_output, GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason") or "", - INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), - OUTPUT_VALUE: str(outputs.get("text", "")), + GEN_AI_INPUT_MESSAGE: gen_ai_input_message, + GEN_AI_OUTPUT_MESSAGE: gen_ai_output_message, }, - status=self.get_workflow_node_status(node_execution), + status=get_workflow_node_status(node_execution), + links=trace_metadata.links, ) - def add_workflow_span( - self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link] - ): + def add_workflow_span(self, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata): message_span_id = None if trace_info.message_id: message_span_id = convert_to_span_id(trace_info.message_id, "message") - user_id = trace_info.metadata.get("user_id") - status: Status = Status(StatusCode.OK) - if trace_info.error: - status = Status(StatusCode.ERROR, trace_info.error) - if message_span_id: # chatflow + status = create_status_from_error(trace_info.error) + + inputs_json = serialize_json_data(trace_info.workflow_run_inputs) + outputs_json = serialize_json_data(trace_info.workflow_run_outputs) + + if message_span_id: message_span = SpanData( - trace_id=trace_id, + trace_id=trace_metadata.trace_id, parent_span_id=None, span_id=message_span_id, name="message", start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), - attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", - GEN_AI_USER_ID: str(user_id), - GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, - GEN_AI_FRAMEWORK: "dify", - INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query") or "", - OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), - }, + attributes=create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.CHAIN, + inputs=trace_info.workflow_run_inputs.get("sys.query") or "", + outputs=outputs_json, + ), status=status, - links=links, + links=trace_metadata.links, ) self.trace_client.add_span(message_span) workflow_span = SpanData( - trace_id=trace_id, + trace_id=trace_metadata.trace_id, parent_span_id=message_span_id, - span_id=workflow_span_id, + span_id=trace_metadata.workflow_span_id, name="workflow", start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), - attributes={ - GEN_AI_USER_ID: str(user_id), - GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, - GEN_AI_FRAMEWORK: "dify", - INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False), - OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), - }, + attributes=create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.CHAIN, + inputs=inputs_json, + outputs=outputs_json, + ), status=status, - links=links, + links=trace_metadata.links, ) self.trace_client.add_span(workflow_span) def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_id = trace_info.message_id - status: Status = Status(StatusCode.OK) - if trace_info.error: - status = Status(StatusCode.ERROR, trace_info.error) + status = create_status_from_error(trace_info.error) - trace_id = convert_to_trace_id(message_id) - links = [] - if trace_info.trace_id: - links.append(create_link(trace_id_str=trace_info.trace_id)) + trace_metadata = TraceMetadata( + trace_id=convert_to_trace_id(message_id), + workflow_span_id=0, + session_id=trace_info.metadata.get("conversation_id") or "", + user_id=str(trace_info.metadata.get("user_id") or ""), + links=create_links_from_trace_id(trace_info.trace_id), + ) + + inputs_json = serialize_json_data(trace_info.inputs) + suggested_question_json = serialize_json_data(trace_info.suggested_question) suggested_question_span = SpanData( - trace_id=trace_id, + trace_id=trace_metadata.trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=convert_to_span_id(message_id, "suggested_question"), name="suggested_question", start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, - GEN_AI_FRAMEWORK: "dify", - GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "", - GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "", - GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False), - GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False), - INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False), + **create_common_span_attributes( + session_id=trace_metadata.session_id, + user_id=trace_metadata.user_id, + span_kind=GenAISpanKind.LLM, + inputs=inputs_json, + outputs=suggested_question_json, + ), + GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "", + GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "", + GEN_AI_PROMPT: inputs_json, + GEN_AI_COMPLETION: suggested_question_json, }, status=status, - links=links, + links=trace_metadata.links, ) self.trace_client.add_span(suggested_question_span) - - -def extract_retrieval_documents(documents: list[Document]): - documents_data = [] - for document in documents: - document_data = { - "content": document.page_content, - "metadata": { - "dataset_id": document.metadata.get("dataset_id"), - "doc_id": document.metadata.get("doc_id"), - "document_id": document.metadata.get("document_id"), - }, - "score": document.metadata.get("score"), - } - documents_data.append(document_data) - return documents_data diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 09cb6e3fc1..f54405b5de 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -7,8 +7,10 @@ import uuid from collections import deque from collections.abc import Sequence from datetime import datetime +from typing import Final +from urllib.parse import urljoin -import requests +import httpx from opentelemetry import trace as trace_api from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import Resource @@ -20,8 +22,12 @@ from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData -INVALID_SPAN_ID = 0x0000000000000000 -INVALID_TRACE_ID = 0x00000000000000000000000000000000 +INVALID_SPAN_ID: Final[int] = 0x0000000000000000 +INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000 +DEFAULT_TIMEOUT: Final[int] = 5 +DEFAULT_MAX_QUEUE_SIZE: Final[int] = 1000 +DEFAULT_SCHEDULE_DELAY_SEC: Final[int] = 5 +DEFAULT_MAX_EXPORT_BATCH_SIZE: Final[int] = 50 logger = logging.getLogger(__name__) @@ -31,9 +37,9 @@ class TraceClient: self, service_name: str, endpoint: str, - max_queue_size: int = 1000, - schedule_delay_sec: int = 5, - max_export_batch_size: int = 50, + max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE, + schedule_delay_sec: int = DEFAULT_SCHEDULE_DELAY_SEC, + max_export_batch_size: int = DEFAULT_MAX_EXPORT_BATCH_SIZE, ): self.endpoint = endpoint self.resource = Resource( @@ -63,24 +69,25 @@ class TraceClient: def export(self, spans: Sequence[ReadableSpan]): self.exporter.export(spans) - def api_check(self): + def api_check(self) -> bool: try: - response = requests.head(self.endpoint, timeout=5) + response = httpx.head(self.endpoint, timeout=DEFAULT_TIMEOUT) if response.status_code == 405: return True else: logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code) return False - except requests.RequestException as e: + except httpx.RequestError as e: logger.debug("AliyunTrace API check failed: %s", str(e)) raise ValueError(f"AliyunTrace API check failed: {str(e)}") - def get_project_url(self): + def get_project_url(self) -> str: return "https://arms.console.aliyun.com/#/llm" - def add_span(self, span_data: SpanData): + def add_span(self, span_data: SpanData | None) -> None: if span_data is None: return + span: ReadableSpan = self.span_builder.build_span(span_data) with self.condition: if len(self.queue) == self.max_queue_size: @@ -92,14 +99,14 @@ class TraceClient: if len(self.queue) >= self.max_export_batch_size: self.condition.notify() - def _worker(self): + def _worker(self) -> None: while not self.done: with self.condition: if len(self.queue) < self.max_export_batch_size and not self.done: self.condition.wait(timeout=self.schedule_delay_sec) self._export_batch() - def _export_batch(self): + def _export_batch(self) -> None: spans_to_export: list[ReadableSpan] = [] with self.condition: while len(spans_to_export) < self.max_export_batch_size and self.queue: @@ -111,7 +118,7 @@ class TraceClient: except Exception as e: logger.debug("Error exporting spans: %s", e) - def shutdown(self): + def shutdown(self) -> None: with self.condition: self.done = True self.condition.notify_all() @@ -121,7 +128,7 @@ class TraceClient: class SpanBuilder: - def __init__(self, resource): + def __init__(self, resource: Resource) -> None: self.resource = resource self.instrumentation_scope = InstrumentationScope( __name__, @@ -167,8 +174,12 @@ class SpanBuilder: def create_link(trace_id_str: str) -> Link: - placeholder_span_id = 0x0000000000000000 - trace_id = int(trace_id_str, 16) + placeholder_span_id = INVALID_SPAN_ID + try: + trace_id = int(trace_id_str, 16) + except ValueError as e: + raise ValueError(f"Invalid trace ID format: {trace_id_str}") from e + span_context = SpanContext( trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED) ) @@ -184,26 +195,29 @@ def generate_span_id() -> int: def convert_to_trace_id(uuid_v4: str | None) -> int: + if uuid_v4 is None: + raise ValueError("UUID cannot be None") try: uuid_obj = uuid.UUID(uuid_v4) return uuid_obj.int - except Exception as e: - raise ValueError(f"Invalid UUID input: {e}") + except ValueError as e: + raise ValueError(f"Invalid UUID input: {uuid_v4}") from e def convert_string_to_id(string: str | None) -> int: if not string: return generate_span_id() hash_bytes = hashlib.sha256(string.encode("utf-8")).digest() - id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) - return id + return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int: + if uuid_v4 is None: + raise ValueError("UUID cannot be None") try: uuid_obj = uuid.UUID(uuid_v4) - except Exception as e: - raise ValueError(f"Invalid UUID input: {e}") + except ValueError as e: + raise ValueError(f"Invalid UUID input: {uuid_v4}") from e combined_key = f"{uuid_obj.hex}-{span_type}" return convert_string_to_id(combined_key) @@ -212,5 +226,11 @@ def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None if start_time_a is None: return None timestamp_in_seconds = start_time_a.timestamp() - timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9) - return timestamp_in_nanoseconds + return int(timestamp_in_seconds * 1e9) + + +def build_endpoint(base_url: str, license_key: str) -> str: + if "log.aliyuncs.com" in base_url: # cms2.0 endpoint + return urljoin(base_url, f"adapt_{license_key}/api/v1/traces") + else: # xtrace endpoint + return urljoin(base_url, f"adapt_{license_key}/api/otlp/traces") diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py index f3dcbc5b8f..20ff2d0875 100644 --- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py +++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py @@ -1,18 +1,34 @@ from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any from opentelemetry import trace as trace_api -from opentelemetry.sdk.trace import Event, Status, StatusCode +from opentelemetry.sdk.trace import Event +from opentelemetry.trace import Status, StatusCode from pydantic import BaseModel, Field +@dataclass +class TraceMetadata: + """Metadata for trace operations, containing common attributes for all spans in a trace.""" + + trace_id: int + workflow_span_id: int + session_id: str + user_id: str + links: list[trace_api.Link] + + class SpanData(BaseModel): + """Data model for span information in Aliyun trace system.""" + model_config = {"arbitrary_types_allowed": True} trace_id: int = Field(..., description="The unique identifier for the trace.") parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.") span_id: int = Field(..., description="The unique identifier for this span.") name: str = Field(..., description="The name of the span.") - attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.") + attributes: dict[str, Any] = Field(default_factory=dict, description="Attributes associated with the span.") events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.") links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.") status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.") diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/core/ops/aliyun_trace/entities/semconv.py index c9427c776a..c823fcab8a 100644 --- a/api/core/ops/aliyun_trace/entities/semconv.py +++ b/api/core/ops/aliyun_trace/entities/semconv.py @@ -1,56 +1,38 @@ from enum import StrEnum +from typing import Final -# public -GEN_AI_SESSION_ID = "gen_ai.session.id" +# Public attributes +GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id" +GEN_AI_USER_ID: Final[str] = "gen_ai.user.id" +GEN_AI_USER_NAME: Final[str] = "gen_ai.user.name" +GEN_AI_SPAN_KIND: Final[str] = "gen_ai.span.kind" +GEN_AI_FRAMEWORK: Final[str] = "gen_ai.framework" -GEN_AI_USER_ID = "gen_ai.user.id" +# Chain attributes +INPUT_VALUE: Final[str] = "input.value" +OUTPUT_VALUE: Final[str] = "output.value" -GEN_AI_USER_NAME = "gen_ai.user.name" +# Retriever attributes +RETRIEVAL_QUERY: Final[str] = "retrieval.query" +RETRIEVAL_DOCUMENT: Final[str] = "retrieval.document" -GEN_AI_SPAN_KIND = "gen_ai.span.kind" +# LLM attributes +GEN_AI_REQUEST_MODEL: Final[str] = "gen_ai.request.model" +GEN_AI_PROVIDER_NAME: Final[str] = "gen_ai.provider.name" +GEN_AI_USAGE_INPUT_TOKENS: Final[str] = "gen_ai.usage.input_tokens" +GEN_AI_USAGE_OUTPUT_TOKENS: Final[str] = "gen_ai.usage.output_tokens" +GEN_AI_USAGE_TOTAL_TOKENS: Final[str] = "gen_ai.usage.total_tokens" +GEN_AI_PROMPT: Final[str] = "gen_ai.prompt" +GEN_AI_COMPLETION: Final[str] = "gen_ai.completion" +GEN_AI_RESPONSE_FINISH_REASON: Final[str] = "gen_ai.response.finish_reason" -GEN_AI_FRAMEWORK = "gen_ai.framework" +GEN_AI_INPUT_MESSAGE: Final[str] = "gen_ai.input.messages" +GEN_AI_OUTPUT_MESSAGE: Final[str] = "gen_ai.output.messages" - -# Chain -INPUT_VALUE = "input.value" - -OUTPUT_VALUE = "output.value" - - -# Retriever -RETRIEVAL_QUERY = "retrieval.query" - -RETRIEVAL_DOCUMENT = "retrieval.document" - - -# LLM -GEN_AI_MODEL_NAME = "gen_ai.model_name" - -GEN_AI_SYSTEM = "gen_ai.system" - -GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" - -GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" - -GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" - -GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template" - -GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable" - -GEN_AI_PROMPT = "gen_ai.prompt" - -GEN_AI_COMPLETION = "gen_ai.completion" - -GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason" - -# Tool -TOOL_NAME = "tool.name" - -TOOL_DESCRIPTION = "tool.description" - -TOOL_PARAMETERS = "tool.parameters" +# Tool attributes +TOOL_NAME: Final[str] = "tool.name" +TOOL_DESCRIPTION: Final[str] = "tool.description" +TOOL_PARAMETERS: Final[str] = "tool.parameters" class GenAISpanKind(StrEnum): diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py new file mode 100644 index 0000000000..7f68889e92 --- /dev/null +++ b/api/core/ops/aliyun_trace/utils.py @@ -0,0 +1,190 @@ +import json +from collections.abc import Mapping +from typing import Any + +from opentelemetry.trace import Link, Status, StatusCode + +from core.ops.aliyun_trace.entities.semconv import ( + GEN_AI_FRAMEWORK, + GEN_AI_SESSION_ID, + GEN_AI_SPAN_KIND, + GEN_AI_USER_ID, + INPUT_VALUE, + OUTPUT_VALUE, + GenAISpanKind, +) +from core.rag.models.document import Document +from core.workflow.entities import WorkflowNodeExecution +from core.workflow.enums import WorkflowNodeExecutionStatus +from extensions.ext_database import db +from models import EndUser + +# Constants +DEFAULT_JSON_ENSURE_ASCII = False +DEFAULT_FRAMEWORK_NAME = "dify" + + +def get_user_id_from_message_data(message_data) -> str: + user_id = message_data.from_account_id + if message_data.from_end_user_id: + end_user_data: EndUser | None = ( + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() + ) + if end_user_data is not None: + user_id = end_user_data.session_id + return user_id + + +def create_status_from_error(error: str | None) -> Status: + if error: + return Status(StatusCode.ERROR, error) + return Status(StatusCode.OK) + + +def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status: + if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED: + return Status(StatusCode.OK) + if node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]: + return Status(StatusCode.ERROR, str(node_execution.error)) + return Status(StatusCode.UNSET) + + +def create_links_from_trace_id(trace_id: str | None) -> list[Link]: + from core.ops.aliyun_trace.data_exporter.traceclient import create_link + + links = [] + if trace_id: + links.append(create_link(trace_id_str=trace_id)) + return links + + +def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]: + documents_data = [] + for document in documents: + document_data = { + "content": document.page_content, + "metadata": { + "dataset_id": document.metadata.get("dataset_id"), + "doc_id": document.metadata.get("doc_id"), + "document_id": document.metadata.get("document_id"), + }, + "score": document.metadata.get("score"), + } + documents_data.append(document_data) + return documents_data + + +def serialize_json_data(data: Any, ensure_ascii: bool = DEFAULT_JSON_ENSURE_ASCII) -> str: + return json.dumps(data, ensure_ascii=ensure_ascii) + + +def create_common_span_attributes( + session_id: str = "", + user_id: str = "", + span_kind: str = GenAISpanKind.CHAIN, + framework: str = DEFAULT_FRAMEWORK_NAME, + inputs: str = "", + outputs: str = "", +) -> dict[str, Any]: + return { + GEN_AI_SESSION_ID: session_id, + GEN_AI_USER_ID: user_id, + GEN_AI_SPAN_KIND: span_kind, + GEN_AI_FRAMEWORK: framework, + INPUT_VALUE: inputs, + OUTPUT_VALUE: outputs, + } + + +def format_retrieval_documents(retrieval_documents: list) -> list: + try: + if not isinstance(retrieval_documents, list): + return [] + + semantic_documents = [] + for doc in retrieval_documents: + if not isinstance(doc, dict): + continue + + metadata = doc.get("metadata", {}) + content = doc.get("content", "") + title = doc.get("title", "") + score = metadata.get("score", 0.0) + document_id = metadata.get("document_id", "") + + semantic_metadata = {} + if title: + semantic_metadata["title"] = title + if metadata.get("source"): + semantic_metadata["source"] = metadata["source"] + elif metadata.get("_source"): + semantic_metadata["source"] = metadata["_source"] + if metadata.get("doc_metadata"): + doc_metadata = metadata["doc_metadata"] + if isinstance(doc_metadata, dict): + semantic_metadata.update(doc_metadata) + + semantic_doc = { + "document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id} + } + semantic_documents.append(semantic_doc) + + return semantic_documents + except Exception: + return [] + + +def format_input_messages(process_data: Mapping[str, Any]) -> str: + try: + if not isinstance(process_data, dict): + return serialize_json_data([]) + + prompts = process_data.get("prompts", []) + if not prompts: + return serialize_json_data([]) + + valid_roles = {"system", "user", "assistant", "tool"} + input_messages = [] + for prompt in prompts: + if not isinstance(prompt, dict): + continue + + role = prompt.get("role", "") + text = prompt.get("text", "") + + if not role or role not in valid_roles: + continue + + if text: + message = {"role": role, "parts": [{"type": "text", "content": text}]} + input_messages.append(message) + + return serialize_json_data(input_messages) + except Exception: + return serialize_json_data([]) + + +def format_output_messages(outputs: Mapping[str, Any]) -> str: + try: + if not isinstance(outputs, dict): + return serialize_json_data([]) + + text = outputs.get("text", "") + finish_reason = outputs.get("finish_reason", "") + + if not text: + return serialize_json_data([]) + + valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"} + if finish_reason not in valid_finish_reasons: + finish_reason = "stop" + + output_message = { + "role": "assistant", + "parts": [{"type": "text", "content": text}], + "finish_reason": finish_reason, + } + + return serialize_json_data([output_message]) + except Exception: + return serialize_json_data([]) diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 1497bc1863..03d2d75372 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -213,9 +213,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_metadata.update(json.loads(node_execution.execution_metadata)) # Determine the correct span kind based on node type - span_kind = OpenInferenceSpanKindValues.CHAIN.value + span_kind = OpenInferenceSpanKindValues.CHAIN if node_execution.node_type == "llm": - span_kind = OpenInferenceSpanKindValues.LLM.value + span_kind = OpenInferenceSpanKindValues.LLM provider = process_data.get("model_provider") model = process_data.get("model_name") if provider: @@ -230,18 +230,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0) node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0) elif node_execution.node_type == "dataset_retrieval": - span_kind = OpenInferenceSpanKindValues.RETRIEVER.value + span_kind = OpenInferenceSpanKindValues.RETRIEVER elif node_execution.node_type == "tool": - span_kind = OpenInferenceSpanKindValues.TOOL.value + span_kind = OpenInferenceSpanKindValues.TOOL else: - span_kind = OpenInferenceSpanKindValues.CHAIN.value + span_kind = OpenInferenceSpanKindValues.CHAIN node_span = self.tracer.start_span( name=node_execution.node_type, attributes={ SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}", SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}", - SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind, + SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False), SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 851a77fbc1..4ba6eb0780 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -191,7 +191,8 @@ class AliyunConfig(BaseTracingConfig): @field_validator("endpoint") @classmethod def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com") + # aliyun uses two URL formats, which may include a URL path + return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") OPS_FILE_PATH = "ops_trace/" diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 931bed78d4..92e6b8ea60 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -73,7 +73,7 @@ class LangFuseDataTrace(BaseTraceInstance): if trace_info.message_id: trace_id = trace_info.trace_id or trace_info.message_id - name = TraceTaskName.MESSAGE_TRACE.value + name = TraceTaskName.MESSAGE_TRACE trace_data = LangfuseTrace( id=trace_id, user_id=user_id, @@ -88,7 +88,7 @@ class LangFuseDataTrace(BaseTraceInstance): self.add_trace(langfuse_trace_data=trace_data) workflow_span_data = LangfuseSpan( id=trace_info.workflow_run_id, - name=TraceTaskName.WORKFLOW_TRACE.value, + name=TraceTaskName.WORKFLOW_TRACE, input=dict(trace_info.workflow_run_inputs), output=dict(trace_info.workflow_run_outputs), trace_id=trace_id, @@ -103,7 +103,7 @@ class LangFuseDataTrace(BaseTraceInstance): trace_data = LangfuseTrace( id=trace_id, user_id=user_id, - name=TraceTaskName.WORKFLOW_TRACE.value, + name=TraceTaskName.WORKFLOW_TRACE, input=dict(trace_info.workflow_run_inputs), output=dict(trace_info.workflow_run_outputs), metadata=metadata, @@ -253,7 +253,7 @@ class LangFuseDataTrace(BaseTraceInstance): trace_data = LangfuseTrace( id=trace_id, user_id=user_id, - name=TraceTaskName.MESSAGE_TRACE.value, + name=TraceTaskName.MESSAGE_TRACE, input={ "message": trace_info.inputs, "files": file_list, @@ -303,7 +303,7 @@ class LangFuseDataTrace(BaseTraceInstance): if trace_info.message_data is None: return span_data = LangfuseSpan( - name=TraceTaskName.MODERATION_TRACE.value, + name=TraceTaskName.MODERATION_TRACE, input=trace_info.inputs, output={ "action": trace_info.action, @@ -331,7 +331,7 @@ class LangFuseDataTrace(BaseTraceInstance): ) generation_data = LangfuseGeneration( - name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, + name=TraceTaskName.SUGGESTED_QUESTION_TRACE, input=trace_info.inputs, output=str(trace_info.suggested_question), trace_id=trace_info.trace_id or trace_info.message_id, @@ -349,7 +349,7 @@ class LangFuseDataTrace(BaseTraceInstance): if trace_info.message_data is None: return dataset_retrieval_span_data = LangfuseSpan( - name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, + name=TraceTaskName.DATASET_RETRIEVAL_TRACE, input=trace_info.inputs, output={"documents": trace_info.documents}, trace_id=trace_info.trace_id or trace_info.message_id, @@ -377,7 +377,7 @@ class LangFuseDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): name_generation_trace_data = LangfuseTrace( - name=TraceTaskName.GENERATE_NAME_TRACE.value, + name=TraceTaskName.GENERATE_NAME_TRACE, input=trace_info.inputs, output=trace_info.outputs, user_id=trace_info.tenant_id, @@ -388,7 +388,7 @@ class LangFuseDataTrace(BaseTraceInstance): self.add_trace(langfuse_trace_data=name_generation_trace_data) name_generation_span_data = LangfuseSpan( - name=TraceTaskName.GENERATE_NAME_TRACE.value, + name=TraceTaskName.GENERATE_NAME_TRACE, input=trace_info.inputs, output=trace_info.outputs, trace_id=trace_info.conversation_id, diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 24a43e1cd8..8b8117b24c 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -81,7 +81,7 @@ class LangSmithDataTrace(BaseTraceInstance): if trace_info.message_id: message_run = LangSmithRunModel( id=trace_info.message_id, - name=TraceTaskName.MESSAGE_TRACE.value, + name=TraceTaskName.MESSAGE_TRACE, inputs=dict(trace_info.workflow_run_inputs), outputs=dict(trace_info.workflow_run_outputs), run_type=LangSmithRunType.chain, @@ -110,7 +110,7 @@ class LangSmithDataTrace(BaseTraceInstance): file_list=trace_info.file_list, total_tokens=trace_info.total_tokens, id=trace_info.workflow_run_id, - name=TraceTaskName.WORKFLOW_TRACE.value, + name=TraceTaskName.WORKFLOW_TRACE, inputs=dict(trace_info.workflow_run_inputs), run_type=LangSmithRunType.tool, start_time=trace_info.workflow_data.created_at, @@ -271,7 +271,7 @@ class LangSmithDataTrace(BaseTraceInstance): output_tokens=trace_info.answer_tokens, total_tokens=trace_info.total_tokens, id=message_id, - name=TraceTaskName.MESSAGE_TRACE.value, + name=TraceTaskName.MESSAGE_TRACE, inputs=trace_info.inputs, run_type=LangSmithRunType.chain, start_time=trace_info.start_time, @@ -327,7 +327,7 @@ class LangSmithDataTrace(BaseTraceInstance): if trace_info.message_data is None: return langsmith_run = LangSmithRunModel( - name=TraceTaskName.MODERATION_TRACE.value, + name=TraceTaskName.MODERATION_TRACE, inputs=trace_info.inputs, outputs={ "action": trace_info.action, @@ -362,7 +362,7 @@ class LangSmithDataTrace(BaseTraceInstance): if message_data is None: return suggested_question_run = LangSmithRunModel( - name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, + name=TraceTaskName.SUGGESTED_QUESTION_TRACE, inputs=trace_info.inputs, outputs=trace_info.suggested_question, run_type=LangSmithRunType.tool, @@ -391,7 +391,7 @@ class LangSmithDataTrace(BaseTraceInstance): if trace_info.message_data is None: return dataset_retrieval_run = LangSmithRunModel( - name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, + name=TraceTaskName.DATASET_RETRIEVAL_TRACE, inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, run_type=LangSmithRunType.retriever, @@ -447,7 +447,7 @@ class LangSmithDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): name_run = LangSmithRunModel( - name=TraceTaskName.GENERATE_NAME_TRACE.value, + name=TraceTaskName.GENERATE_NAME_TRACE, inputs=trace_info.inputs, outputs=trace_info.outputs, run_type=LangSmithRunType.tool, diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 8fa92f9fcd..8050c59db9 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -108,7 +108,7 @@ class OpikDataTrace(BaseTraceInstance): trace_data = { "id": opik_trace_id, - "name": TraceTaskName.MESSAGE_TRACE.value, + "name": TraceTaskName.MESSAGE_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": workflow_metadata, @@ -125,7 +125,7 @@ class OpikDataTrace(BaseTraceInstance): "id": root_span_id, "parent_span_id": None, "trace_id": opik_trace_id, - "name": TraceTaskName.WORKFLOW_TRACE.value, + "name": TraceTaskName.WORKFLOW_TRACE, "input": wrap_dict("input", trace_info.workflow_run_inputs), "output": wrap_dict("output", trace_info.workflow_run_outputs), "start_time": trace_info.start_time, @@ -138,7 +138,7 @@ class OpikDataTrace(BaseTraceInstance): else: trace_data = { "id": opik_trace_id, - "name": TraceTaskName.MESSAGE_TRACE.value, + "name": TraceTaskName.MESSAGE_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": workflow_metadata, @@ -290,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance): trace_data = { "id": prepare_opik_uuid(trace_info.start_time, dify_trace_id), - "name": TraceTaskName.MESSAGE_TRACE.value, + "name": TraceTaskName.MESSAGE_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": wrap_metadata(metadata), @@ -329,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance): span_data = { "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), - "name": TraceTaskName.MODERATION_TRACE.value, + "name": TraceTaskName.MODERATION_TRACE, "type": "tool", "start_time": start_time, "end_time": trace_info.end_time or trace_info.message_data.updated_at, @@ -355,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance): span_data = { "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), - "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value, + "name": TraceTaskName.SUGGESTED_QUESTION_TRACE, "type": "tool", "start_time": start_time, "end_time": trace_info.end_time or message_data.updated_at, @@ -375,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance): span_data = { "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), - "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value, + "name": TraceTaskName.DATASET_RETRIEVAL_TRACE, "type": "tool", "start_time": start_time, "end_time": trace_info.end_time or trace_info.message_data.updated_at, @@ -405,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): trace_data = { "id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id), - "name": TraceTaskName.GENERATE_NAME_TRACE.value, + "name": TraceTaskName.GENERATE_NAME_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": wrap_metadata(trace_info.metadata), @@ -420,7 +420,7 @@ class OpikDataTrace(BaseTraceInstance): span_data = { "trace_id": trace.id, - "name": TraceTaskName.GENERATE_NAME_TRACE.value, + "name": TraceTaskName.GENERATE_NAME_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": wrap_metadata(trace_info.metadata), diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 0679b27271..e181373bd0 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -155,7 +155,10 @@ class OpsTraceManager: if key in tracing_config: if "*" in tracing_config[key]: # If the key contains '*', retain the original value from the current config - new_config[key] = current_trace_config.get(key, tracing_config[key]) + if current_trace_config: + new_config[key] = current_trace_config.get(key, tracing_config[key]) + else: + new_config[key] = tracing_config[key] else: # Otherwise, encrypt the key new_config[key] = encrypt_token(tenant_id, tracing_config[key]) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 339694cf07..9b3d7a8192 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -62,7 +62,8 @@ class WeaveDataTrace(BaseTraceInstance): self, ): try: - project_url = f"https://wandb.ai/{self.weave_client._project_id()}" + project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name + project_url = f"https://wandb.ai/{project_identifier}" return project_url except Exception as e: logger.debug("Weave get run url failed: %s", str(e)) @@ -103,7 +104,7 @@ class WeaveDataTrace(BaseTraceInstance): message_run = WeaveTraceModel( id=trace_info.message_id, - op=str(TraceTaskName.MESSAGE_TRACE.value), + op=str(TraceTaskName.MESSAGE_TRACE), inputs=dict(trace_info.workflow_run_inputs), outputs=dict(trace_info.workflow_run_outputs), total_tokens=trace_info.total_tokens, @@ -125,7 +126,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=trace_info.file_list, total_tokens=trace_info.total_tokens, id=trace_info.workflow_run_id, - op=str(TraceTaskName.WORKFLOW_TRACE.value), + op=str(TraceTaskName.WORKFLOW_TRACE), inputs=dict(trace_info.workflow_run_inputs), outputs=dict(trace_info.workflow_run_outputs), attributes=workflow_attributes, @@ -252,7 +253,7 @@ class WeaveDataTrace(BaseTraceInstance): message_run = WeaveTraceModel( id=trace_id, - op=str(TraceTaskName.MESSAGE_TRACE.value), + op=str(TraceTaskName.MESSAGE_TRACE), input_tokens=trace_info.message_tokens, output_tokens=trace_info.answer_tokens, total_tokens=trace_info.total_tokens, @@ -299,7 +300,7 @@ class WeaveDataTrace(BaseTraceInstance): moderation_run = WeaveTraceModel( id=str(uuid.uuid4()), - op=str(TraceTaskName.MODERATION_TRACE.value), + op=str(TraceTaskName.MODERATION_TRACE), inputs=trace_info.inputs, outputs={ "action": trace_info.action, @@ -329,7 +330,7 @@ class WeaveDataTrace(BaseTraceInstance): suggested_question_run = WeaveTraceModel( id=str(uuid.uuid4()), - op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value), + op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE), inputs=trace_info.inputs, outputs=trace_info.suggested_question, attributes=attributes, @@ -354,7 +355,7 @@ class WeaveDataTrace(BaseTraceInstance): dataset_retrieval_run = WeaveTraceModel( id=str(uuid.uuid4()), - op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value), + op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE), inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, attributes=attributes, @@ -396,7 +397,7 @@ class WeaveDataTrace(BaseTraceInstance): name_run = WeaveTraceModel( id=str(uuid.uuid4()), - op=str(TraceTaskName.GENERATE_NAME_TRACE.value), + op=str(TraceTaskName.GENERATE_NAME_TRACE), inputs=trace_info.inputs, outputs=trace_info.outputs, attributes=attributes, @@ -424,7 +425,23 @@ class WeaveDataTrace(BaseTraceInstance): raise ValueError(f"Weave API check failed: {str(e)}") def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None): - call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes) + inputs = run_data.inputs + if inputs is None: + inputs = {} + elif not isinstance(inputs, dict): + inputs = {"inputs": str(inputs)} + + attributes = run_data.attributes + if attributes is None: + attributes = {} + elif not isinstance(attributes, dict): + attributes = {"attributes": str(attributes)} + + call = self.weave_client.create_call( + op=run_data.op, + inputs=inputs, + attributes=attributes, + ) self.calls[run_data.id] = call if parent_run_id: self.calls[run_data.id].parent_id = parent_run_id @@ -432,6 +449,7 @@ class WeaveDataTrace(BaseTraceInstance): def finish_call(self, run_data: WeaveTraceModel): call = self.calls.get(run_data.id) if call: - self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception) + exception = Exception(run_data.exception) if run_data.exception else None + self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception) else: raise ValueError(f"Call with id {run_data.id} not found") diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 1d6d21cff7..9fbcbf55b4 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): instruction=instruction, # instruct with variables are not supported ) node_data_dict = node_data.model_dump() - node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value + node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR execution = workflow_service.run_free_workflow_node( node_data_dict, tenant_id=tenant_id, diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 10f37f75f8..d5df85730b 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -83,16 +83,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel): raise ValueError("prompt_messages must be a list") for i in range(len(v)): - if v[i]["role"] == PromptMessageRole.USER.value: - v[i] = UserPromptMessage(**v[i]) - elif v[i]["role"] == PromptMessageRole.ASSISTANT.value: - v[i] = AssistantPromptMessage(**v[i]) - elif v[i]["role"] == PromptMessageRole.SYSTEM.value: - v[i] = SystemPromptMessage(**v[i]) - elif v[i]["role"] == PromptMessageRole.TOOL.value: - v[i] = ToolPromptMessage(**v[i]) + if v[i]["role"] == PromptMessageRole.USER: + v[i] = UserPromptMessage.model_validate(v[i]) + elif v[i]["role"] == PromptMessageRole.ASSISTANT: + v[i] = AssistantPromptMessage.model_validate(v[i]) + elif v[i]["role"] == PromptMessageRole.SYSTEM: + v[i] = SystemPromptMessage.model_validate(v[i]) + elif v[i]["role"] == PromptMessageRole.TOOL: + v[i] = ToolPromptMessage.model_validate(v[i]) else: - v[i] = PromptMessage(**v[i]) + v[i] = PromptMessage.model_validate(v[i]) return v diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 8e3df4da2c..c791b35161 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -2,11 +2,10 @@ import inspect import json import logging from collections.abc import Callable, Generator -from typing import TypeVar +from typing import Any, TypeVar -import requests +import httpx from pydantic import BaseModel -from requests.exceptions import HTTPError from yarl import URL from configs import dify_config @@ -47,29 +46,56 @@ class BasePluginClient: data: bytes | dict | str | None = None, params: dict | None = None, files: dict | None = None, - stream: bool = False, - ) -> requests.Response: + ) -> httpx.Response: """ Make a request to the plugin daemon inner API. """ - url = plugin_daemon_inner_api_baseurl / path - headers = headers or {} - headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY - headers["Accept-Encoding"] = "gzip, deflate, br" + url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files) - if headers.get("Content-Type") == "application/json" and isinstance(data, dict): - data = json.dumps(data) + request_kwargs: dict[str, Any] = { + "method": method, + "url": url, + "headers": headers, + "params": params, + "files": files, + } + if isinstance(prepared_data, dict): + request_kwargs["data"] = prepared_data + elif prepared_data is not None: + request_kwargs["content"] = prepared_data try: - response = requests.request( - method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files - ) - except requests.ConnectionError: + response = httpx.request(**request_kwargs) + except httpx.RequestError: logger.exception("Request to Plugin Daemon Service failed") raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") return response + def _prepare_request( + self, + path: str, + headers: dict | None, + data: bytes | dict | str | None, + params: dict | None, + files: dict | None, + ) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]: + url = plugin_daemon_inner_api_baseurl / path + prepared_headers = dict(headers or {}) + prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY + prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br") + + prepared_data: bytes | dict | str | None = ( + data if isinstance(data, (bytes, str, dict)) or data is None else None + ) + if isinstance(data, dict): + if prepared_headers.get("Content-Type") == "application/json": + prepared_data = json.dumps(data) + else: + prepared_data = data + + return str(url), prepared_headers, prepared_data, params, files + def _stream_request( self, method: str, @@ -78,23 +104,44 @@ class BasePluginClient: headers: dict | None = None, data: bytes | dict | None = None, files: dict | None = None, - ) -> Generator[bytes, None, None]: + ) -> Generator[str, None, None]: """ Make a stream request to the plugin daemon inner API """ - response = self._request(method, path, headers, data, params, files, stream=True) - for line in response.iter_lines(chunk_size=1024 * 8): - line = line.decode("utf-8").strip() - if line.startswith("data:"): - line = line[5:].strip() - if line: - yield line + url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files) + + stream_kwargs: dict[str, Any] = { + "method": method, + "url": url, + "headers": headers, + "params": params, + "files": files, + } + if isinstance(prepared_data, dict): + stream_kwargs["data"] = prepared_data + elif prepared_data is not None: + stream_kwargs["content"] = prepared_data + + try: + with httpx.stream(**stream_kwargs) as response: + for raw_line in response.iter_lines(): + if raw_line is None: + continue + line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line + line = line.strip() + if line.startswith("data:"): + line = line[5:].strip() + if line: + yield line + except httpx.RequestError: + logger.exception("Stream request to Plugin Daemon Service failed") + raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") def _stream_request_with_model( self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -104,13 +151,13 @@ class BasePluginClient: Make a stream request to the plugin daemon inner API and yield the response as a model. """ for line in self._stream_request(method, path, params, headers, data, files): - yield type(**json.loads(line)) # type: ignore + yield type_(**json.loads(line)) # type: ignore def _request_with_model( self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | None = None, params: dict | None = None, @@ -120,13 +167,13 @@ class BasePluginClient: Make a request to the plugin daemon inner API and return the response as a model. """ response = self._request(method, path, headers, data, params, files) - return type(**response.json()) # type: ignore + return type_(**response.json()) # type: ignore def _request_with_plugin_daemon_response( self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -139,23 +186,23 @@ class BasePluginClient: try: response = self._request(method, path, headers, data, params, files) response.raise_for_status() - except HTTPError as e: - msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}" - logger.exception(msg) + except httpx.HTTPStatusError as e: + logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path) raise e except Exception as e: msg = f"Failed to request plugin daemon, url: {path}" - logger.exception(msg) + logger.exception("Failed to request plugin daemon, url: %s", path) raise ValueError(msg) from e try: json_response = response.json() if transformer: json_response = transformer(json_response) - rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore + # https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why + rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore except Exception: msg = ( - f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}]," + f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}]," f" url: {path}" ) logger.exception(msg) @@ -163,7 +210,7 @@ class BasePluginClient: if rep.code != 0: try: - error = PluginDaemonError(**json.loads(rep.message)) + error = PluginDaemonError.model_validate(json.loads(rep.message)) except Exception: raise ValueError(f"{rep.message}, code: {rep.code}") @@ -178,7 +225,7 @@ class BasePluginClient: self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -189,7 +236,7 @@ class BasePluginClient: """ for line in self._stream_request(method, path, params, headers, data, files): try: - rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore + rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore except (ValueError, TypeError): # TODO modify this when line_data has code and message try: @@ -204,7 +251,7 @@ class BasePluginClient: if rep.code != 0: if rep.code == -500: try: - error = PluginDaemonError(**json.loads(rep.message)) + error = PluginDaemonError.model_validate(json.loads(rep.message)) except Exception: raise PluginDaemonInnerError(code=rep.code, message=rep.message) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 84087f8104..ce1ef71494 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient): params={"page": 1, "page_size": 256}, transformer=transformer, ) - local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate( + self._get_local_file_datasource_provider() + ) for provider in response: ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) @@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient): Fetch datasource provider for the given tenant and plugin. """ if provider_id == "langgenius/file/file": - return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider()) tool_provider_id = DatasourceProviderID(provider_id) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 153da142f4..5dfc3c212e 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/llm/invoke", - type=LLMResultChunk, + type_=LLMResultChunk, data=jsonable_encoder( { "user_id": user_id, @@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", - type=PluginLLMNumTokensResponse, + type_=PluginLLMNumTokensResponse, data=jsonable_encoder( { "user_id": user_id, @@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", - type=TextEmbeddingResult, + type_=TextEmbeddingResult, data=jsonable_encoder( { "user_id": user_id, @@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", - type=PluginTextEmbeddingNumTokensResponse, + type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( { "user_id": user_id, @@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/rerank/invoke", - type=RerankResult, + type_=RerankResult, data=jsonable_encoder( { "user_id": user_id, @@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/tts/invoke", - type=PluginStringResultResponse, + type_=PluginStringResultResponse, data=jsonable_encoder( { "user_id": user_id, @@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/tts/model/voices", - type=PluginVoicesResponse, + type_=PluginVoicesResponse, data=jsonable_encoder( { "user_id": user_id, @@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", - type=PluginStringResultResponse, + type_=PluginStringResultResponse, data=jsonable_encoder( { "user_id": user_id, @@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/moderation/invoke", - type=PluginBasicBooleanResponse, + type_=PluginBasicBooleanResponse, data=jsonable_encoder( { "user_id": user_id, diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py index e30076f9d3..28cb70f96a 100644 --- a/api/core/plugin/utils/chunk_merger.py +++ b/api/core/plugin/utils/chunk_merger.py @@ -1,6 +1,6 @@ from collections.abc import Generator from dataclasses import dataclass, field -from typing import TypeVar, Union, cast +from typing import TypeVar, Union from core.agent.entities import AgentInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage @@ -87,7 +87,8 @@ def merge_blob_chunks( ), meta=resp.meta, ) - yield cast(MessageType, merged_message) + assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage)) + yield merged_message # type: ignore # Clean up the buffer del files[chunk_id] else: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6f642ab5db..7bc9830ac3 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -513,6 +513,21 @@ class ProviderManager: return provider_name_to_provider_load_balancing_model_configs_dict + @staticmethod + def _get_provider_names(provider_name: str) -> list[str]: + """ + provider_name: `openai` or `langgenius/openai/openai` + return: [`openai`, `langgenius/openai/openai`] + """ + provider_names = [provider_name] + model_provider_id = ModelProviderID(provider_name) + if model_provider_id.is_langgenius(): + if "/" in provider_name: + provider_names.append(model_provider_id.provider_name) + else: + provider_names.append(str(model_provider_id)) + return provider_names + @staticmethod def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: """ @@ -525,7 +540,10 @@ class ProviderManager: with Session(db.engine, expire_on_commit=False) as session: stmt = ( select(ProviderCredential) - .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) + .where( + ProviderCredential.tenant_id == tenant_id, + ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)), + ) .order_by(ProviderCredential.created_at.desc()) ) @@ -554,7 +572,7 @@ class ProviderManager: select(ProviderModelCredential) .where( ProviderModelCredential.tenant_id == tenant_id, - ProviderModelCredential.provider_name == provider_name, + ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)), ProviderModelCredential.model_name == model_name, ProviderModelCredential.model_type == model_type, ) @@ -592,7 +610,7 @@ class ProviderManager: provider_quota_to_provider_record_dict = {} for provider_record in provider_records: - if provider_record.provider_type != ProviderType.SYSTEM.value: + if provider_record.provider_type != ProviderType.SYSTEM: continue provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( @@ -609,8 +627,8 @@ class ProviderManager: tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, - provider_type=ProviderType.SYSTEM.value, - quota_type=ProviderQuotaType.TRIAL.value, + provider_type=ProviderType.SYSTEM, + quota_type=ProviderQuotaType.TRIAL, quota_limit=quota.quota_limit, # type: ignore quota_used=0, is_valid=True, @@ -623,8 +641,8 @@ class ProviderManager: stmt = select(Provider).where( Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value, + Provider.provider_type == ProviderType.SYSTEM, + Provider.quota_type == ProviderQuotaType.TRIAL, ) existed_provider_record = db.session.scalar(stmt) if not existed_provider_record: @@ -684,7 +702,7 @@ class ProviderManager: """Get custom provider configuration.""" # Find custom provider record (non-system) custom_provider_record = next( - (record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None + (record for record in provider_records if record.provider_type != ProviderType.SYSTEM), None ) if not custom_provider_record: @@ -887,7 +905,7 @@ class ProviderManager: # Convert provider_records to dict quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} for provider_record in provider_records: - if provider_record.provider_type != ProviderType.SYSTEM.value: + if provider_record.provider_type != ProviderType.SYSTEM: continue quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( @@ -1028,7 +1046,7 @@ class ProviderManager: """ secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: - if credential_form_schema.type.value == FormType.SECRET_INPUT.value: + if credential_form_schema.type.value == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.variable) return secret_input_form_variables diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 696e3e967f..cc946a72c3 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -46,7 +46,7 @@ class DataPostProcessor: reranking_model: dict | None = None, weights: dict | None = None, ) -> BaseRerankRunner | None: - if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: + if reranking_mode == RerankMode.WEIGHTED_SCORE and weights: runner = RerankRunnerFactory.create_rerank_runner( runner_type=reranking_mode, tenant_id=tenant_id, @@ -62,7 +62,7 @@ class DataPostProcessor: ), ) return runner - elif reranking_mode == RerankMode.RERANKING_MODEL.value: + elif reranking_mode == RerankMode.RERANKING_MODEL: rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model) if rerank_model_instance is None: return None diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 429744c0de..6e9e2b4527 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -21,7 +21,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 4, @@ -106,7 +106,9 @@ class RetrievalService: if exceptions: raise ValueError(";\n".join(exceptions)) - if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: + # Deduplicate documents for hybrid search to avoid duplicate chunks + if retrieval_method == RetrievalMethod.HYBRID_SEARCH: + all_documents = cls._deduplicate_documents(all_documents) data_post_processor = DataPostProcessor( str(dataset.tenant_id), reranking_mode, reranking_model, weights, False ) @@ -132,7 +134,7 @@ class RetrievalService: if not dataset: return [] metadata_condition = ( - MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None + MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None ) all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( dataset.tenant_id, @@ -143,6 +145,40 @@ class RetrievalService: ) return all_documents + @classmethod + def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]: + """Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search.""" + if not documents: + return documents + + unique_documents = [] + seen_doc_ids = set() + + for document in documents: + # For dify provider documents, use doc_id for deduplication + if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata: + doc_id = document.metadata["doc_id"] + if doc_id not in seen_doc_ids: + seen_doc_ids.add(doc_id) + unique_documents.append(document) + # If duplicate, keep the one with higher score + elif "score" in document.metadata: + # Find existing document with same doc_id and compare scores + for i, existing_doc in enumerate(unique_documents): + if ( + existing_doc.metadata + and existing_doc.metadata.get("doc_id") == doc_id + and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0) + ): + unique_documents[i] = document + break + else: + # For non-dify documents, use content-based deduplication + if document not in unique_documents: + unique_documents.append(document) + + return unique_documents + @classmethod def _get_dataset(cls, dataset_id: str) -> Dataset | None: with Session(db.engine) as session: @@ -209,10 +245,10 @@ class RetrievalService: reranking_model and reranking_model.get("reranking_model_name") and reranking_model.get("reranking_provider_name") - and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH ): data_post_processor = DataPostProcessor( - str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False + str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) all_documents.extend( data_post_processor.invoke( @@ -257,10 +293,10 @@ class RetrievalService: reranking_model and reranking_model.get("reranking_model_name") and reranking_model.get("reranking_provider_name") - and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH ): data_post_processor = DataPostProcessor( - str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False + str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) all_documents.extend( data_post_processor.invoke( diff --git a/docker/volumes/sandbox/dependencies/python-requirements.txt b/api/core/rag/datasource/vdb/alibabacloud_mysql/__init__.py similarity index 100% rename from docker/volumes/sandbox/dependencies/python-requirements.txt rename to api/core/rag/datasource/vdb/alibabacloud_mysql/__init__.py diff --git a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py b/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py new file mode 100644 index 0000000000..fdb5ffebfc --- /dev/null +++ b/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py @@ -0,0 +1,388 @@ +import hashlib +import json +import logging +import uuid +from contextlib import contextmanager +from typing import Any, Literal, cast + +import mysql.connector +from mysql.connector import Error as MySQLError +from pydantic import BaseModel, model_validator + +from configs import dify_config +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +class AlibabaCloudMySQLVectorConfig(BaseModel): + host: str + port: int + user: str + password: str + database: str + max_connection: int + charset: str = "utf8mb4" + distance_function: Literal["cosine", "euclidean"] = "cosine" + hnsw_m: int = 6 + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict): + if not values.get("host"): + raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required") + if not values.get("port"): + raise ValueError("config ALIBABACLOUD_MYSQL_PORT is required") + if not values.get("user"): + raise ValueError("config ALIBABACLOUD_MYSQL_USER is required") + if values.get("password") is None: + raise ValueError("config ALIBABACLOUD_MYSQL_PASSWORD is required") + if not values.get("database"): + raise ValueError("config ALIBABACLOUD_MYSQL_DATABASE is required") + if not values.get("max_connection"): + raise ValueError("config ALIBABACLOUD_MYSQL_MAX_CONNECTION is required") + return values + + +SQL_CREATE_TABLE = """ +CREATE TABLE IF NOT EXISTS {table_name} ( + id VARCHAR(36) PRIMARY KEY, + text LONGTEXT NOT NULL, + meta JSON NOT NULL, + embedding VECTOR({dimension}) NOT NULL, + VECTOR INDEX (embedding) M={hnsw_m} DISTANCE={distance_function} +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +""" + +SQL_CREATE_META_INDEX = """ +CREATE INDEX idx_{index_hash}_meta ON {table_name} + ((CAST(JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) AS CHAR(36)))); +""" + +SQL_CREATE_FULLTEXT_INDEX = """ +CREATE FULLTEXT INDEX idx_{index_hash}_text ON {table_name} (text) WITH PARSER ngram; +""" + + +class AlibabaCloudMySQLVector(BaseVector): + def __init__(self, collection_name: str, config: AlibabaCloudMySQLVectorConfig): + super().__init__(collection_name) + self.pool = self._create_connection_pool(config) + self.table_name = collection_name.lower() + self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8] + self.distance_function = config.distance_function.lower() + self.hnsw_m = config.hnsw_m + self._check_vector_support() + + def get_type(self) -> str: + return VectorType.ALIBABACLOUD_MYSQL + + def _create_connection_pool(self, config: AlibabaCloudMySQLVectorConfig): + # Create connection pool using mysql-connector-python pooling + pool_config: dict[str, Any] = { + "host": config.host, + "port": config.port, + "user": config.user, + "password": config.password, + "database": config.database, + "charset": config.charset, + "autocommit": True, + "pool_name": f"pool_{self.collection_name}", + "pool_size": config.max_connection, + "pool_reset_session": True, + } + return mysql.connector.pooling.MySQLConnectionPool(**pool_config) + + def _check_vector_support(self): + """Check if the MySQL server supports vector operations.""" + try: + with self._get_cursor() as cur: + # Check MySQL version and vector support + cur.execute("SELECT VERSION()") + version = cur.fetchone()["VERSION()"] + logger.debug("Connected to MySQL version: %s", version) + # Try to execute a simple vector function to verify support + cur.execute("SELECT VEC_FromText('[1,2,3]') IS NOT NULL as vector_support") + result = cur.fetchone() + if not result or not result.get("vector_support"): + raise ValueError( + "RDS MySQL Vector functions are not available." + " Please ensure you're using RDS MySQL 8.0.36+ with Vector support." + ) + + except MySQLError as e: + if "FUNCTION" in str(e) and "VEC_FromText" in str(e): + raise ValueError( + "RDS MySQL Vector functions are not available." + " Please ensure you're using RDS MySQL 8.0.36+ with Vector support." + ) from e + raise e + + @contextmanager + def _get_cursor(self): + conn = self.pool.get_connection() + cur = conn.cursor(dictionary=True) + try: + yield cur + finally: + cur.close() + conn.close() + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + return self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + values = [] + pks = [] + for i, doc in enumerate(documents): + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + # Convert embedding list to Aliyun MySQL vector format + vector_str = "[" + ",".join(map(str, embeddings[i])) + "]" + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + vector_str, + ) + ) + + with self._get_cursor() as cur: + insert_sql = ( + f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (%s, %s, %s, VEC_FromText(%s))" + ) + cur.executemany(insert_sql, values) + return pks + + def text_exists(self, id: str) -> bool: + with self._get_cursor() as cur: + cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,)) + return cur.fetchone() is not None + + def get_by_ids(self, ids: list[str]) -> list[Document]: + if not ids: + return [] + + with self._get_cursor() as cur: + placeholders = ",".join(["%s"] * len(ids)) + cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids) + docs = [] + for record in cur: + metadata = record["meta"] + if isinstance(metadata, str): + metadata = json.loads(metadata) + docs.append(Document(page_content=record["text"], metadata=metadata)) + return docs + + def delete_by_ids(self, ids: list[str]): + # Avoiding crashes caused by performing delete operations on empty lists + if not ids: + return + + with self._get_cursor() as cur: + try: + placeholders = ",".join(["%s"] * len(ids)) + cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids) + except MySQLError as e: + if e.errno == 1146: # Table doesn't exist + logger.warning("Table %s not found, skipping delete operation.", self.table_name) + return + else: + raise e + + def delete_by_metadata_field(self, key: str, value: str): + with self._get_cursor() as cur: + cur.execute( + f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value) + ) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """ + Search the nearest neighbors to a vector using RDS MySQL vector distance functions. + + :param query_vector: The input vector to search for similar items. + :return: List of Documents that are nearest to the query vector. + """ + top_k = kwargs.get("top_k", 4) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") + + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + params = [] + + if document_ids_filter: + placeholders = ",".join(["%s"] * len(document_ids_filter)) + where_clause = f" WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) " + params.extend(document_ids_filter) + + # Convert query vector to RDS MySQL vector format + query_vector_str = "[" + ",".join(map(str, query_vector)) + "]" + + # Use RSD MySQL's native vector distance functions + with self._get_cursor() as cur: + # Choose distance function based on configuration + distance_func = "VEC_DISTANCE_COSINE" if self.distance_function == "cosine" else "VEC_DISTANCE_EUCLIDEAN" + + # Note: RDS MySQL optimizer will use vector index when ORDER BY + LIMIT are present + # Use column alias in ORDER BY to avoid calculating distance twice + sql = f""" + SELECT meta, text, + {distance_func}(embedding, VEC_FromText(%s)) AS distance + FROM {self.table_name} + {where_clause} + ORDER BY distance + LIMIT %s + """ + query_params = [query_vector_str] + params + [top_k] + + cur.execute(sql, query_params) + + docs = [] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + + for record in cur: + try: + distance = float(record["distance"]) + # Convert distance to similarity score + if self.distance_function == "cosine": + # For cosine distance: similarity = 1 - distance + similarity = 1.0 - distance + else: + # For euclidean distance: use inverse relationship + # similarity = 1 / (1 + distance) + similarity = 1.0 / (1.0 + distance) + + metadata = record["meta"] + if isinstance(metadata, str): + metadata = json.loads(metadata) + metadata["score"] = similarity + metadata["distance"] = distance + + if similarity >= score_threshold: + docs.append(Document(page_content=record["text"], metadata=metadata)) + except (ValueError, json.JSONDecodeError) as e: + logger.warning("Error processing search result: %s", e) + continue + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 5) + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("top_k must be a positive integer") + + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + params = [] + + if document_ids_filter: + placeholders = ",".join(["%s"] * len(document_ids_filter)) + where_clause = f" AND JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) " + params.extend(document_ids_filter) + + with self._get_cursor() as cur: + # Build query parameters: query (twice for MATCH clauses), document_ids_filter (if any), top_k + query_params = [query, query] + params + [top_k] + cur.execute( + f"""SELECT meta, text, + MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE) AS score + FROM {self.table_name} + WHERE MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE) + {where_clause} + ORDER BY score DESC + LIMIT %s""", + query_params, + ) + docs = [] + for record in cur: + metadata = record["meta"] + if isinstance(metadata, str): + metadata = json.loads(metadata) + metadata["score"] = float(record["score"]) + docs.append(Document(page_content=record["text"], metadata=metadata)) + return docs + + def delete(self): + with self._get_cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + + def _create_collection(self, dimension: int): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + lock_name = f"{collection_exist_cache_key}_lock" + with redis_client.lock(lock_name, timeout=20): + if redis_client.get(collection_exist_cache_key): + return + + with self._get_cursor() as cur: + # Create table with vector column and vector index + cur.execute( + SQL_CREATE_TABLE.format( + table_name=self.table_name, + dimension=dimension, + distance_function=self.distance_function, + hnsw_m=self.hnsw_m, + ) + ) + # Create metadata index (check if exists first) + try: + cur.execute(SQL_CREATE_META_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)) + except MySQLError as e: + if e.errno != 1061: # Duplicate key name + logger.warning("Could not create meta index: %s", e) + + # Create full-text index for text search + try: + cur.execute( + SQL_CREATE_FULLTEXT_INDEX.format(table_name=self.table_name, index_hash=self.index_hash) + ) + except MySQLError as e: + if e.errno != 1061: # Duplicate key name + logger.warning("Could not create fulltext index: %s", e) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + +class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory): + def _validate_distance_function(self, distance_function: str) -> Literal["cosine", "euclidean"]: + """Validate and return the distance function as a proper Literal type.""" + if distance_function not in ["cosine", "euclidean"]: + raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'") + return cast(Literal["cosine", "euclidean"], distance_function) + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.ALIBABACLOUD_MYSQL, collection_name) + ) + return AlibabaCloudMySQLVector( + collection_name=collection_name, + config=AlibabaCloudMySQLVectorConfig( + host=dify_config.ALIBABACLOUD_MYSQL_HOST or "localhost", + port=dify_config.ALIBABACLOUD_MYSQL_PORT, + user=dify_config.ALIBABACLOUD_MYSQL_USER or "root", + password=dify_config.ALIBABACLOUD_MYSQL_PASSWORD or "", + database=dify_config.ALIBABACLOUD_MYSQL_DATABASE or "dify", + max_connection=dify_config.ALIBABACLOUD_MYSQL_MAX_CONNECTION, + charset=dify_config.ALIBABACLOUD_MYSQL_CHARSET or "utf8mb4", + distance_function=self._validate_distance_function( + dify_config.ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION or "cosine" + ), + hnsw_m=dify_config.ALIBABACLOUD_MYSQL_HNSW_M or 6, + ), + ) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index e55e5f3101..a306f9ba0c 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -488,9 +488,9 @@ class ClickzettaVector(BaseVector): create_table_sql = f""" CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} ( id STRING NOT NULL COMMENT 'Unique document identifier', - {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval', - {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes', - {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT + {Field.CONTENT_KEY} STRING NOT NULL COMMENT 'Document text content for search and retrieval', + {Field.METADATA_KEY} JSON COMMENT 'Document metadata including source, type, and other attributes', + {Field.VECTOR} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT 'High-dimensional embedding vector for semantic similarity search', PRIMARY KEY (id) ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' @@ -519,15 +519,15 @@ class ClickzettaVector(BaseVector): existing_indexes = cursor.fetchall() for idx in existing_indexes: # Check if vector index already exists on the embedding column - if Field.VECTOR.value in str(idx).lower(): - logger.info("Vector index already exists on column %s", Field.VECTOR.value) + if Field.VECTOR in str(idx).lower(): + logger.info("Vector index already exists on column %s", Field.VECTOR) return except (RuntimeError, ValueError) as e: logger.warning("Failed to check existing indexes: %s", e) index_sql = f""" CREATE VECTOR INDEX IF NOT EXISTS {index_name} - ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value}) + ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR}) PROPERTIES ( "distance.function" = "{self._config.vector_distance_function}", "scalar.type" = "f32", @@ -560,17 +560,17 @@ class ClickzettaVector(BaseVector): # More precise check: look for inverted index specifically on the content column if ( "inverted" in idx_str - and Field.CONTENT_KEY.value.lower() in idx_str + and Field.CONTENT_KEY.lower() in idx_str and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str) ): - logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx) + logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY, idx) return except (RuntimeError, ValueError) as e: logger.warning("Failed to check existing indexes: %s", e) index_sql = f""" CREATE INVERTED INDEX IF NOT EXISTS {index_name} - ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value}) + ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY}) PROPERTIES ( "analyzer" = "{self._config.analyzer_type}", "mode" = "{self._config.analyzer_mode}" @@ -588,13 +588,13 @@ class ClickzettaVector(BaseVector): or "with the same type" in error_msg or "cannot create inverted index" in error_msg ) and "already has index" in error_msg: - logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value) + logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY) # Try to get the existing index name for logging try: cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") existing_indexes = cursor.fetchall() for idx in existing_indexes: - if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower(): + if "inverted" in str(idx).lower() and Field.CONTENT_KEY.lower() in str(idx).lower(): logger.info("Found existing inverted index: %s", idx) break except (RuntimeError, ValueError): @@ -669,7 +669,7 @@ class ClickzettaVector(BaseVector): # Use parameterized INSERT with executemany for better performance and security # Cast JSON and VECTOR in SQL, pass raw data as parameters - columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}" + columns = f"id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}" insert_sql = ( f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) " f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" @@ -767,7 +767,7 @@ class ClickzettaVector(BaseVector): # Use json_extract_string function for ClickZetta compatibility sql = ( f"DELETE FROM {self._config.schema_name}.{self._table_name} " - f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?" + f"WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?" ) cursor.execute(sql, binding_params=[value]) @@ -795,9 +795,7 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append( - f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" - ) + filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})") # No need for dataset_id filter since each dataset has its own table @@ -808,23 +806,21 @@ class ClickzettaVector(BaseVector): distance_func = "COSINE_DISTANCE" if score_threshold > 0: query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" - filter_clauses.append( - f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}" - ) + filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {2 - score_threshold}") else: # For L2 distance, smaller is better distance_func = "L2_DISTANCE" if score_threshold > 0: query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" - filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}") + filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {score_threshold}") where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1" # Execute vector search query query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" search_sql = f""" - SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, - {distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance + SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, + {distance_func}({Field.VECTOR}, {query_vector_str}) AS distance FROM {self._config.schema_name}.{self._table_name} WHERE {where_clause} ORDER BY distance @@ -887,9 +883,7 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append( - f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" - ) + filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})") # No need for dataset_id filter since each dataset has its own table @@ -897,13 +891,13 @@ class ClickzettaVector(BaseVector): # match_all requires all terms to be present # Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause escaped_query = query.replace("'", "''") - filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')") + filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY}, '{escaped_query}')") where_clause = " AND ".join(filter_clauses) # Execute full-text search query search_sql = f""" - SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY} FROM {self._config.schema_name}.{self._table_name} WHERE {where_clause} LIMIT {top_k} @@ -986,19 +980,17 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append( - f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" - ) + filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})") # No need for dataset_id filter since each dataset has its own table # Use simple quote escaping for LIKE clause escaped_query = query.replace("'", "''") - filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'") + filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'") where_clause = " AND ".join(filter_clauses) search_sql = f""" - SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY} FROM {self._config.schema_name}.{self._table_name} WHERE {where_clause} LIMIT {top_k} diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py index 7b00928b7b..1e7fe52666 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py @@ -57,18 +57,18 @@ class ElasticSearchJaVector(ElasticSearchVector): } mappings = { "properties": { - Field.CONTENT_KEY.value: { + Field.CONTENT_KEY: { "type": "text", "analyzer": "ja_analyzer", "search_analyzer": "ja_analyzer", }, - Field.VECTOR.value: { # Make sure the dimension is correct here + Field.VECTOR: { # Make sure the dimension is correct here "type": "dense_vector", "dims": dim, "index": True, "similarity": "cosine", }, - Field.METADATA_KEY.value: { + Field.METADATA_KEY: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 2c147fa7ca..0ff8c915e6 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -4,7 +4,7 @@ import math from typing import Any, cast from urllib.parse import urlparse -import requests +from elasticsearch import ConnectionError as ElasticsearchConnectionError from elasticsearch import Elasticsearch from flask import current_app from packaging.version import parse as parse_version @@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector): if not client.ping(): raise ConnectionError("Failed to connect to Elasticsearch") - except requests.ConnectionError as e: + except ElasticsearchConnectionError as e: raise ConnectionError(f"Vector database connection error: {str(e)}") except Exception as e: raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") @@ -163,9 +163,9 @@ class ElasticSearchVector(BaseVector): index=self._collection_name, id=uuids[i], document={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i] or None, - Field.METADATA_KEY.value: documents[i].metadata or {}, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i] or None, + Field.METADATA_KEY: documents[i].metadata or {}, }, ) self._client.indices.refresh(index=self._collection_name) @@ -193,7 +193,7 @@ class ElasticSearchVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) num_candidates = math.ceil(top_k * 1.5) - knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} + knn = {"field": Field.VECTOR, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} @@ -205,9 +205,9 @@ class ElasticSearchVector(BaseVector): docs_and_scores.append( ( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ), hit["_score"], ) @@ -224,13 +224,13 @@ class ElasticSearchVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}} + query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}} document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: query_str = { "bool": { - "must": {"match": {Field.CONTENT_KEY.value: query}}, + "must": {"match": {Field.CONTENT_KEY: query}}, "filter": {"terms": {"metadata.document_id": document_ids_filter}}, } } @@ -240,9 +240,9 @@ class ElasticSearchVector(BaseVector): for hit in results["hits"]["hits"]: docs.append( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ) ) @@ -270,14 +270,14 @@ class ElasticSearchVector(BaseVector): dim = len(embeddings[0]) mappings = { "properties": { - Field.CONTENT_KEY.value: {"type": "text"}, - Field.VECTOR.value: { # Make sure the dimension is correct here + Field.CONTENT_KEY: {"type": "text"}, + Field.VECTOR: { # Make sure the dimension is correct here "type": "dense_vector", "dims": dim, "index": True, "similarity": "cosine", }, - Field.METADATA_KEY.value: { + Field.METADATA_KEY: { "type": "object", "properties": { "doc_id": {"type": "keyword"}, # Map doc_id to keyword type diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index cfee090768..c7b6593a8f 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -67,9 +67,9 @@ class HuaweiCloudVector(BaseVector): index=self._collection_name, id=uuids[i], document={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i] or None, - Field.METADATA_KEY.value: documents[i].metadata or {}, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i] or None, + Field.METADATA_KEY: documents[i].metadata or {}, }, ) self._client.indices.refresh(index=self._collection_name) @@ -101,7 +101,7 @@ class HuaweiCloudVector(BaseVector): "size": top_k, "query": { "vector": { - Field.VECTOR.value: { + Field.VECTOR: { "vector": query_vector, "topk": top_k, } @@ -116,9 +116,9 @@ class HuaweiCloudVector(BaseVector): docs_and_scores.append( ( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ), hit["_score"], ) @@ -135,15 +135,15 @@ class HuaweiCloudVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = {"match": {Field.CONTENT_KEY.value: query}} + query_str = {"match": {Field.CONTENT_KEY: query}} results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) docs = [] for hit in results["hits"]["hits"]: docs.append( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ) ) @@ -171,8 +171,8 @@ class HuaweiCloudVector(BaseVector): dim = len(embeddings[0]) mappings = { "properties": { - Field.CONTENT_KEY.value: {"type": "text"}, - Field.VECTOR.value: { # Make sure the dimension is correct here + Field.CONTENT_KEY: {"type": "text"}, + Field.VECTOR: { # Make sure the dimension is correct here "type": "vector", "dimension": dim, "indexing": True, @@ -181,7 +181,7 @@ class HuaweiCloudVector(BaseVector): "neighbors": 32, "efc": 128, }, - Field.METADATA_KEY.value: { + Field.METADATA_KEY: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 8824e1c67b..bfcb620618 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -125,9 +125,9 @@ class LindormVectorStore(BaseVector): } } action_values: dict[str, Any] = { - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i], + Field.METADATA_KEY: documents[i].metadata, } if self._using_ugc: action_header["index"]["routing"] = self._routing @@ -149,7 +149,7 @@ class LindormVectorStore(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): query: dict[str, Any] = { - "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}} + "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}} } if self._using_ugc: query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}}) @@ -252,14 +252,14 @@ class LindormVectorStore(BaseVector): search_query: dict[str, Any] = { "size": top_k, "_source": True, - "query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}}, + "query": {"knn": {Field.VECTOR: {"vector": query_vector, "k": top_k}}}, } final_ext: dict[str, Any] = {"lvector": {}} if filters is not None and len(filters) > 0: # when using filter, transform filter from List[Dict] to Dict as valid format filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] - search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict # filter should be Dict + search_query["query"]["knn"][Field.VECTOR]["filter"] = filter_dict # filter should be Dict final_ext["lvector"]["filter_type"] = "pre_filter" if final_ext != {"lvector": {}}: @@ -279,9 +279,9 @@ class LindormVectorStore(BaseVector): docs_and_scores.append( ( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ), hit["_score"], ) @@ -318,9 +318,9 @@ class LindormVectorStore(BaseVector): docs = [] for hit in response["hits"]["hits"]: - metadata = hit["_source"].get(Field.METADATA_KEY.value) - vector = hit["_source"].get(Field.VECTOR.value) - page_content = hit["_source"].get(Field.CONTENT_KEY.value) + metadata = hit["_source"].get(Field.METADATA_KEY) + vector = hit["_source"].get(Field.VECTOR) + page_content = hit["_source"].get(Field.CONTENT_KEY) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) @@ -342,8 +342,8 @@ class LindormVectorStore(BaseVector): "settings": {"index": {"knn": True, "knn_routing": self._using_ugc}}, "mappings": { "properties": { - Field.CONTENT_KEY.value: {"type": "text"}, - Field.VECTOR.value: { + Field.CONTENT_KEY: {"type": "text"}, + Field.VECTOR: { "type": "knn_vector", "dimension": len(embeddings[0]), # Make sure the dimension is correct here "method": { diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 5f32feb709..96eb465401 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -85,7 +85,7 @@ class MilvusVector(BaseVector): collection_info = self._client.describe_collection(self._collection_name) fields = [field["name"] for field in collection_info["fields"]] # Since primary field is auto-id, no need to track it - self._fields = [f for f in fields if f != Field.PRIMARY_KEY.value] + self._fields = [f for f in fields if f != Field.PRIMARY_KEY] def _check_hybrid_search_support(self) -> bool: """ @@ -130,9 +130,9 @@ class MilvusVector(BaseVector): insert_dict = { # Do not need to insert the sparse_vector field separately, as the text_bm25_emb # function will automatically convert the native text into a sparse vector for us. - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i], + Field.METADATA_KEY: documents[i].metadata, } insert_dict_list.append(insert_dict) # Total insert count @@ -243,15 +243,15 @@ class MilvusVector(BaseVector): results = self._client.search( collection_name=self._collection_name, data=[query_vector], - anns_field=Field.VECTOR.value, + anns_field=Field.VECTOR, limit=kwargs.get("top_k", 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY], filter=filter, ) return self._process_search_results( results, - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY], score_threshold=float(kwargs.get("score_threshold") or 0.0), ) @@ -264,7 +264,7 @@ class MilvusVector(BaseVector): "Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)." ) return [] - if not self.field_exists(Field.SPARSE_VECTOR.value): + if not self.field_exists(Field.SPARSE_VECTOR): logger.warning( "Full-text search unavailable: collection missing 'sparse_vector' field; " "recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index." @@ -279,15 +279,15 @@ class MilvusVector(BaseVector): results = self._client.search( collection_name=self._collection_name, data=[query], - anns_field=Field.SPARSE_VECTOR.value, + anns_field=Field.SPARSE_VECTOR, limit=kwargs.get("top_k", 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY], filter=filter, ) return self._process_search_results( results, - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY], score_threshold=float(kwargs.get("score_threshold") or 0.0), ) @@ -311,7 +311,7 @@ class MilvusVector(BaseVector): dim = len(embeddings[0]) fields = [] if metadatas: - fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) + fields.append(FieldSchema(Field.METADATA_KEY, DataType.JSON, max_length=65_535)) # Create the text field, enable_analyzer will be set True to support milvus automatically # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md @@ -326,15 +326,15 @@ class MilvusVector(BaseVector): ): content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params - fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs)) + fields.append(FieldSchema(Field.CONTENT_KEY, DataType.VARCHAR, **content_field_kwargs)) # Create the primary key field - fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) + fields.append(FieldSchema(Field.PRIMARY_KEY, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors - fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) + fields.append(FieldSchema(Field.VECTOR, infer_dtype_bydata(embeddings[0]), dim=dim)) # Create Sparse Vector Index for the collection if self._hybrid_search_enabled: - fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR)) + fields.append(FieldSchema(Field.SPARSE_VECTOR, DataType.SPARSE_FLOAT_VECTOR)) schema = CollectionSchema(fields) @@ -342,8 +342,8 @@ class MilvusVector(BaseVector): if self._hybrid_search_enabled: bm25_function = Function( name="text_bm25_emb", - input_field_names=[Field.CONTENT_KEY.value], - output_field_names=[Field.SPARSE_VECTOR.value], + input_field_names=[Field.CONTENT_KEY], + output_field_names=[Field.SPARSE_VECTOR], function_type=FunctionType.BM25, ) schema.add_function(bm25_function) @@ -352,12 +352,12 @@ class MilvusVector(BaseVector): # Create Index params for the collection index_params_obj = IndexParams() - index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) + index_params_obj.add_index(field_name=Field.VECTOR, **index_params) # Create Sparse Vector Index for the collection if self._hybrid_search_enabled: index_params_obj.add_index( - field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25" + field_name=Field.SPARSE_VECTOR, index_type="AUTOINDEX", metric_type="BM25" ) # Create the collection diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 49cf900126..b3db7332e8 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -123,7 +123,7 @@ class OceanBaseVector(BaseVector): # Get parser from config or use default ik parser parser_name = dify_config.OCEANBASE_FULLTEXT_PARSER or "ik" - allowed_parsers = ["ik", "japanese_ftparser", "thai_ftparser"] + allowed_parsers = ["ngram", "beng", "space", "ngram2", "ik", "japanese_ftparser", "thai_ftparser"] if parser_name not in allowed_parsers: raise ValueError( f"Invalid OceanBase full-text parser: {parser_name}. " diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 3eb1df027e..80ffdadd96 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Literal +from typing import Any from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers @@ -8,6 +8,7 @@ from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from configs import dify_config +from configs.middleware.vdb.opensearch_config import AuthMethod from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -25,7 +26,7 @@ class OpenSearchConfig(BaseModel): port: int secure: bool = False # use_ssl verify_certs: bool = True - auth_method: Literal["basic", "aws_managed_iam"] = "basic" + auth_method: AuthMethod = AuthMethod.BASIC user: str | None = None password: str | None = None aws_region: str | None = None @@ -98,9 +99,9 @@ class OpenSearchVector(BaseVector): "_op_type": "index", "_index": self._collection_name.lower(), "_source": { - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], # Make sure you pass an array here - Field.METADATA_KEY.value: documents[i].metadata, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i], # Make sure you pass an array here + Field.METADATA_KEY: documents[i].metadata, }, } # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377 @@ -116,7 +117,7 @@ class OpenSearchVector(BaseVector): ) def get_ids_by_metadata_field(self, key: str, value: str): - query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} + query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}} response = self._client.search(index=self._collection_name.lower(), body=query) if response["hits"]["hits"]: return [hit["_id"] for hit in response["hits"]["hits"]] @@ -180,17 +181,17 @@ class OpenSearchVector(BaseVector): query = { "size": kwargs.get("top_k", 4), - "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, + "query": {"knn": {Field.VECTOR: {Field.VECTOR: query_vector, "k": kwargs.get("top_k", 4)}}}, } document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: query["query"] = { "script_score": { - "query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}}, + "query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID: document_ids_filter}}]}}, "script": { "source": "knn_score", "lang": "knn", - "params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"}, + "params": {"field": Field.VECTOR, "query_value": query_vector, "space_type": "l2"}, }, } } @@ -203,7 +204,7 @@ class OpenSearchVector(BaseVector): docs = [] for hit in response["hits"]["hits"]: - metadata = hit["_source"].get(Field.METADATA_KEY.value, {}) + metadata = hit["_source"].get(Field.METADATA_KEY, {}) # Make sure metadata is a dictionary if metadata is None: @@ -212,7 +213,7 @@ class OpenSearchVector(BaseVector): metadata["score"] = hit["_score"] score_threshold = float(kwargs.get("score_threshold") or 0.0) if hit["_score"] >= score_threshold: - doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) + doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY), metadata=metadata) docs.append(doc) return docs @@ -227,9 +228,9 @@ class OpenSearchVector(BaseVector): docs = [] for hit in response["hits"]["hits"]: - metadata = hit["_source"].get(Field.METADATA_KEY.value) - vector = hit["_source"].get(Field.VECTOR.value) - page_content = hit["_source"].get(Field.CONTENT_KEY.value) + metadata = hit["_source"].get(Field.METADATA_KEY) + vector = hit["_source"].get(Field.VECTOR) + page_content = hit["_source"].get(Field.CONTENT_KEY) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) @@ -250,8 +251,8 @@ class OpenSearchVector(BaseVector): "settings": {"index": {"knn": True}}, "mappings": { "properties": { - Field.CONTENT_KEY.value: {"type": "text"}, - Field.VECTOR.value: { + Field.CONTENT_KEY: {"type": "text"}, + Field.VECTOR: { "type": "knn_vector", "dimension": len(embeddings[0]), # Make sure the dimension is correct here "method": { @@ -261,7 +262,7 @@ class OpenSearchVector(BaseVector): "parameters": {"ef_construction": 64, "m": 8}, }, }, - Field.METADATA_KEY.value: { + Field.METADATA_KEY: { "type": "object", "properties": { "doc_id": {"type": "keyword"}, # Map doc_id to keyword type @@ -293,7 +294,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory): port=dify_config.OPENSEARCH_PORT, secure=dify_config.OPENSEARCH_SECURE, verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS, - auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value, + auth_method=dify_config.OPENSEARCH_AUTH_METHOD, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, aws_region=dify_config.OPENSEARCH_AWS_REGION, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index d46f29bd64..f8c62b908a 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -147,15 +147,13 @@ class QdrantVector(BaseVector): # create group_id payload index self._client.create_payload_index( - collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD ) # create doc_id payload index - self._client.create_payload_index( - collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD - ) + self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD) # create document_id payload index self._client.create_payload_index( - collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD + collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD ) # create full text index text_index_params = TextIndexParams( @@ -165,9 +163,7 @@ class QdrantVector(BaseVector): max_token_len=20, lowercase=True, ) - self._client.create_payload_index( - collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params - ) + self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -220,10 +216,10 @@ class QdrantVector(BaseVector): self._build_payloads( batch_texts, batch_metadatas, - Field.CONTENT_KEY.value, - Field.METADATA_KEY.value, + Field.CONTENT_KEY, + Field.METADATA_KEY, group_id or "", # Ensure group_id is never None - Field.GROUP_KEY.value, + Field.GROUP_KEY, ), ) ] @@ -381,12 +377,12 @@ class QdrantVector(BaseVector): for result in results: if result.payload is None: continue - metadata = result.payload.get(Field.METADATA_KEY.value) or {} + metadata = result.payload.get(Field.METADATA_KEY) or {} # duplicate check score threshold if result.score >= score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value, ""), + page_content=result.payload.get(Field.CONTENT_KEY, ""), metadata=metadata, ) docs.append(doc) @@ -433,7 +429,7 @@ class QdrantVector(BaseVector): documents = [] for result in results: if result: - document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) + document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) documents.append(document) return documents diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index e91d9bb0d6..f2156afa59 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -55,7 +55,7 @@ class TableStoreVector(BaseVector): self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score self._table_name = f"{collection_name}" self._index_name = f"{collection_name}_idx" - self._tags_field = f"{Field.METADATA_KEY.value}_tags" + self._tags_field = f"{Field.METADATA_KEY}_tags" def create_collection(self, embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) @@ -64,7 +64,7 @@ class TableStoreVector(BaseVector): def get_by_ids(self, ids: list[str]) -> list[Document]: docs = [] request = BatchGetRowRequest() - columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value] + columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY] rows_to_get = [[("id", _id)] for _id in ids] request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1)) @@ -73,11 +73,7 @@ class TableStoreVector(BaseVector): for item in table_result: if item.is_ok and item.row: kv = {k: v for k, v, _ in item.row.attribute_columns} - docs.append( - Document( - page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value]) - ) - ) + docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY]))) return docs def get_type(self) -> str: @@ -95,9 +91,9 @@ class TableStoreVector(BaseVector): self._write_row( primary_key=uuids[i], attributes={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i], + Field.METADATA_KEY: documents[i].metadata, }, ) return uuids @@ -180,7 +176,7 @@ class TableStoreVector(BaseVector): field_schemas = [ tablestore.FieldSchema( - Field.CONTENT_KEY.value, + Field.CONTENT_KEY, tablestore.FieldType.TEXT, analyzer=tablestore.AnalyzerType.MAXWORD, index=True, @@ -188,7 +184,7 @@ class TableStoreVector(BaseVector): store=False, ), tablestore.FieldSchema( - Field.VECTOR.value, + Field.VECTOR, tablestore.FieldType.VECTOR, vector_options=tablestore.VectorOptions( data_type=tablestore.VectorDataType.VD_FLOAT_32, @@ -197,7 +193,7 @@ class TableStoreVector(BaseVector): ), ), tablestore.FieldSchema( - Field.METADATA_KEY.value, + Field.METADATA_KEY, tablestore.FieldType.KEYWORD, index=True, store=False, @@ -233,15 +229,15 @@ class TableStoreVector(BaseVector): pk = [("id", primary_key)] tags = [] - for key, value in attributes[Field.METADATA_KEY.value].items(): + for key, value in attributes[Field.METADATA_KEY].items(): tags.append(str(key) + "=" + str(value)) attribute_columns = [ - (Field.CONTENT_KEY.value, attributes[Field.CONTENT_KEY.value]), - (Field.VECTOR.value, json.dumps(attributes[Field.VECTOR.value])), + (Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]), + (Field.VECTOR, json.dumps(attributes[Field.VECTOR])), ( - Field.METADATA_KEY.value, - json.dumps(attributes[Field.METADATA_KEY.value]), + Field.METADATA_KEY, + json.dumps(attributes[Field.METADATA_KEY]), ), (self._tags_field, json.dumps(tags)), ] @@ -270,7 +266,7 @@ class TableStoreVector(BaseVector): index_name=self._index_name, search_query=query, columns_to_get=tablestore.ColumnsToGet( - column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED + column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED ), ) @@ -288,7 +284,7 @@ class TableStoreVector(BaseVector): self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float ) -> list[Document]: knn_vector_query = tablestore.KnnVectorQuery( - field_name=Field.VECTOR.value, + field_name=Field.VECTOR, top_k=top_k, float32_query_vector=query_vector, ) @@ -311,8 +307,8 @@ class TableStoreVector(BaseVector): for col in search_hit.row[1]: ots_column_map[col[0]] = col[1] - vector_str = ots_column_map.get(Field.VECTOR.value) - metadata_str = ots_column_map.get(Field.METADATA_KEY.value) + vector_str = ots_column_map.get(Field.VECTOR) + metadata_str = ots_column_map.get(Field.METADATA_KEY) vector = json.loads(vector_str) if vector_str else None metadata = json.loads(metadata_str) if metadata_str else {} @@ -321,7 +317,7 @@ class TableStoreVector(BaseVector): documents.append( Document( - page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", + page_content=ots_column_map.get(Field.CONTENT_KEY) or "", vector=vector, metadata=metadata, ) @@ -343,7 +339,7 @@ class TableStoreVector(BaseVector): self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float ) -> list[Document]: bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[]) - bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) + bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY)) if document_ids_filter: bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter)) @@ -374,10 +370,10 @@ class TableStoreVector(BaseVector): for col in search_hit.row[1]: ots_column_map[col[0]] = col[1] - metadata_str = ots_column_map.get(Field.METADATA_KEY.value) + metadata_str = ots_column_map.get(Field.METADATA_KEY) metadata = json.loads(metadata_str) if metadata_str else {} - vector_str = ots_column_map.get(Field.VECTOR.value) + vector_str = ots_column_map.get(Field.VECTOR) vector = json.loads(vector_str) if vector_str else None if score: @@ -385,7 +381,7 @@ class TableStoreVector(BaseVector): documents.append( Document( - page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", + page_content=ots_column_map.get(Field.CONTENT_KEY) or "", vector=vector, metadata=metadata, ) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index f90a311df4..56ffb36a2b 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -5,9 +5,10 @@ from collections.abc import Generator, Iterable, Sequence from itertools import islice from typing import TYPE_CHECKING, Any, Union +import httpx import qdrant_client -import requests from flask import current_app +from httpx import DigestAuth from pydantic import BaseModel from qdrant_client.http import models as rest from qdrant_client.http.models import ( @@ -19,7 +20,6 @@ from qdrant_client.http.models import ( TokenizerType, ) from qdrant_client.local.qdrant_local import QdrantLocal -from requests.auth import HTTPDigestAuth from sqlalchemy import select from configs import dify_config @@ -141,15 +141,13 @@ class TidbOnQdrantVector(BaseVector): # create group_id payload index self._client.create_payload_index( - collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD ) # create doc_id payload index - self._client.create_payload_index( - collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD - ) + self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD) # create document_id payload index self._client.create_payload_index( - collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD + collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD ) # create full text index text_index_params = TextIndexParams( @@ -159,9 +157,7 @@ class TidbOnQdrantVector(BaseVector): max_token_len=20, lowercase=True, ) - self._client.create_payload_index( - collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params - ) + self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -211,10 +207,10 @@ class TidbOnQdrantVector(BaseVector): self._build_payloads( batch_texts, batch_metadatas, - Field.CONTENT_KEY.value, - Field.METADATA_KEY.value, + Field.CONTENT_KEY, + Field.METADATA_KEY, group_id or "", - Field.GROUP_KEY.value, + Field.GROUP_KEY, ), ) ] @@ -349,13 +345,13 @@ class TidbOnQdrantVector(BaseVector): for result in results: if result.payload is None: continue - metadata = result.payload.get(Field.METADATA_KEY.value) or {} + metadata = result.payload.get(Field.METADATA_KEY) or {} # duplicate check score threshold score_threshold = kwargs.get("score_threshold") or 0.0 if result.score >= score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value, ""), + page_content=result.payload.get(Field.CONTENT_KEY, ""), metadata=metadata, ) docs.append(doc) @@ -392,7 +388,7 @@ class TidbOnQdrantVector(BaseVector): documents = [] for result in results: if result: - document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) + document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) documents.append(document) return documents @@ -504,10 +500,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): } cluster_data = {"displayName": display_name, "region": region_object, "labels": labels} - response = requests.post( + response = httpx.post( f"{tidb_config.api_url}/clusters", json=cluster_data, - auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + auth=DigestAuth(tidb_config.public_key, tidb_config.private_key), ) if response.status_code == 200: @@ -527,10 +523,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): body = {"password": new_password} - response = requests.put( + response = httpx.put( f"{tidb_config.api_url}/clusters/{cluster_id}/password", json=body, - auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + auth=DigestAuth(tidb_config.public_key, tidb_config.private_key), ) if response.status_code == 200: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index e1d4422144..754c149241 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -2,8 +2,8 @@ import time import uuid from collections.abc import Sequence -import requests -from requests.auth import HTTPDigestAuth +import httpx +from httpx import DigestAuth from configs import dify_config from extensions.ext_database import db @@ -49,7 +49,7 @@ class TidbService: "rootPassword": password, } - response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key)) + response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)) if response.status_code == 200: response_data = response.json() @@ -83,7 +83,7 @@ class TidbService: :return: The response from the API. """ - response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) if response.status_code == 200: return response.json() @@ -102,7 +102,7 @@ class TidbService: :return: The response from the API. """ - response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) if response.status_code == 200: return response.json() @@ -127,10 +127,10 @@ class TidbService: body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} - response = requests.patch( + response = httpx.patch( f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", json=body, - auth=HTTPDigestAuth(public_key, private_key), + auth=DigestAuth(public_key, private_key), ) if response.status_code == 200: @@ -161,9 +161,7 @@ class TidbService: tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} cluster_ids = [item.cluster_id for item in tidb_serverless_list] params = {"clusterIds": cluster_ids, "view": "BASIC"} - response = requests.get( - f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key) - ) + response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key)) if response.status_code == 200: response_data = response.json() @@ -224,8 +222,8 @@ class TidbService: clusters.append(cluster_data) request_body = {"requests": clusters} - response = requests.post( - f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key) + response = httpx.post( + f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key) ) if response.status_code == 200: diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index b8897c4165..27ae038a06 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -55,13 +55,13 @@ class TiDBVector(BaseVector): return Table( self._collection_name, self._orm_base.metadata, - Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False), + Column(Field.PRIMARY_KEY, String(36), primary_key=True, nullable=False), Column( - Field.VECTOR.value, + Field.VECTOR, VectorType(dim), nullable=False, ), - Column(Field.TEXT_KEY.value, TEXT, nullable=False), + Column(Field.TEXT_KEY, TEXT, nullable=False), Column("meta", JSON, nullable=False), Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), Column( diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index dc4f026ff3..0beb388693 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -71,6 +71,12 @@ class Vector: from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory return MilvusVectorFactory + case VectorType.ALIBABACLOUD_MYSQL: + from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import ( + AlibabaCloudMySQLVectorFactory, + ) + + return AlibabaCloudMySQLVectorFactory case VectorType.MYSCALE: from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index a415142196..bc7d93a2e0 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -2,6 +2,7 @@ from enum import StrEnum class VectorType(StrEnum): + ALIBABACLOUD_MYSQL = "alibabacloud_mysql" ANALYTICDB = "analyticdb" CHROMA = "chroma" MILVUS = "milvus" diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index d1bdd3baef..e5feecf2bc 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -76,11 +76,11 @@ class VikingDBVector(BaseVector): if not self._has_collection(): fields = [ - Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True), - Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String), - Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String), - Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text), - Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension), + Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True), + Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String), + Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String), + Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text), + Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=dimension), ] self._client.create_collection( @@ -100,7 +100,7 @@ class VikingDBVector(BaseVector): collection_name=self._collection_name, index_name=self._index_name, vector_index=vector_index, - partition_by=vdb_Field.GROUP_KEY.value, + partition_by=vdb_Field.GROUP_KEY, description="Index For Dify", ) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -126,11 +126,11 @@ class VikingDBVector(BaseVector): # FIXME: fix the type of metadata later doc = Data( { - vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore - vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, - vdb_Field.CONTENT_KEY.value: page_content, - vdb_Field.METADATA_KEY.value: json.dumps(metadata), - vdb_Field.GROUP_KEY.value: self._group_id, + vdb_Field.PRIMARY_KEY: metadatas[i]["doc_id"], # type: ignore + vdb_Field.VECTOR: embeddings[i] if embeddings else None, + vdb_Field.CONTENT_KEY: page_content, + vdb_Field.METADATA_KEY: json.dumps(metadata), + vdb_Field.GROUP_KEY: self._group_id, } ) docs.append(doc) @@ -151,7 +151,7 @@ class VikingDBVector(BaseVector): # Note: Metadata field value is an dict, but vikingdb field # not support json type results = self._client.get_index(self._collection_name, self._index_name).search( - filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]}, + filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]}, # max value is 5000 limit=5000, ) @@ -161,7 +161,7 @@ class VikingDBVector(BaseVector): ids = [] for result in results: - metadata = result.fields.get(vdb_Field.METADATA_KEY.value) + metadata = result.fields.get(vdb_Field.METADATA_KEY) if metadata is not None: metadata = json.loads(metadata) if metadata.get(key) == value: @@ -189,12 +189,12 @@ class VikingDBVector(BaseVector): docs = [] for result in results: - metadata = result.fields.get(vdb_Field.METADATA_KEY.value) + metadata = result.fields.get(vdb_Field.METADATA_KEY) if metadata is not None: metadata = json.loads(metadata) if result.score >= score_threshold: metadata["score"] = result.score - doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) + doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata) docs.append(doc) docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 3ec08b93ed..8820c0a846 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -2,7 +2,6 @@ import datetime import json from typing import Any -import requests import weaviate # type: ignore from pydantic import BaseModel, model_validator @@ -45,8 +44,8 @@ class WeaviateVector(BaseVector): client = weaviate.Client( url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None ) - except requests.ConnectionError: - raise ConnectionError("Vector database connection error") + except Exception as exc: + raise ConnectionError("Vector database connection error") from exc client.batch.configure( # `batch_size` takes an `int` value to enable auto-batching @@ -105,7 +104,7 @@ class WeaviateVector(BaseVector): with self._client.batch as batch: for i, text in enumerate(texts): - data_properties = {Field.TEXT_KEY.value: text} + data_properties = {Field.TEXT_KEY: text} if metadatas is not None: # metadata maybe None for key, val in (metadatas[i] or {}).items(): @@ -183,7 +182,7 @@ class WeaviateVector(BaseVector): """Look up similar documents by embedding vector in Weaviate.""" collection_name = self._collection_name properties = self._attributes - properties.append(Field.TEXT_KEY.value) + properties.append(Field.TEXT_KEY) query_obj = self._client.query.get(collection_name, properties) vector = {"vector": query_vector} @@ -205,7 +204,7 @@ class WeaviateVector(BaseVector): docs_and_scores = [] for res in result["data"]["Get"][collection_name]: - text = res.pop(Field.TEXT_KEY.value) + text = res.pop(Field.TEXT_KEY) score = 1 - res["_additional"]["distance"] docs_and_scores.append((Document(page_content=text, metadata=res), score)) @@ -233,7 +232,7 @@ class WeaviateVector(BaseVector): collection_name = self._collection_name content: dict[str, Any] = {"concepts": [query]} properties = self._attributes - properties.append(Field.TEXT_KEY.value) + properties.append(Field.TEXT_KEY) if kwargs.get("search_distance"): content["certainty"] = kwargs.get("search_distance") query_obj = self._client.query.get(collection_name, properties) @@ -251,7 +250,7 @@ class WeaviateVector(BaseVector): raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][collection_name]: - text = res.pop(Field.TEXT_KEY.value) + text = res.pop(Field.TEXT_KEY) additional = res.pop("_additional") docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) return docs diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 5f94129a0c..c2f17cd148 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -42,6 +42,10 @@ class CacheEmbedding(Embeddings): text_embeddings[i] = embedding.get_embedding() else: embedding_queue_indices.append(i) + + # release database connection, because embedding may take a long time + db.session.close() + if embedding_queue_indices: embedding_queue_texts = [texts[i] for i in embedding_queue_indices] embedding_queue_embeddings = [] diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index 24db5d77be..a61b17ddb8 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -20,12 +20,12 @@ class BaseDatasourceEvent(BaseModel): class DatasourceErrorEvent(BaseDatasourceEvent): - event: str = DatasourceStreamEvent.ERROR.value + event: DatasourceStreamEvent = DatasourceStreamEvent.ERROR error: str = Field(..., description="error message") class DatasourceCompletedEvent(BaseDatasourceEvent): - event: str = DatasourceStreamEvent.COMPLETED.value + event: DatasourceStreamEvent = DatasourceStreamEvent.COMPLETED data: Mapping[str, Any] | list = Field(..., description="result") total: int | None = Field(default=0, description="total") completed: int | None = Field(default=0, description="completed") @@ -33,6 +33,6 @@ class DatasourceCompletedEvent(BaseDatasourceEvent): class DatasourceProcessingEvent(BaseDatasourceEvent): - event: str = DatasourceStreamEvent.PROCESSING.value + event: DatasourceStreamEvent = DatasourceStreamEvent.PROCESSING total: int | None = Field(..., description="total") completed: int | None = Field(..., description="completed") diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index b9bf9d0d8c..c3bfbce98f 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -17,9 +17,6 @@ class NotionInfo(BaseModel): tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) - def __init__(self, **data): - super().__init__(**data) - class WebsiteInfo(BaseModel): """ @@ -47,6 +44,3 @@ class ExtractSetting(BaseModel): website_info: WebsiteInfo | None = None document_model: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) - - def __init__(self, **data): - super().__init__(**data) diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 3dc08e1832..0f62f9c4b6 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -45,7 +45,7 @@ class ExtractProcessor: cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model" + datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model" ) if return_text: delimiter = "\n" @@ -76,7 +76,7 @@ class ExtractProcessor: # https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521 file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}" Path(file_path).write_bytes(response.content) - extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model") + extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model") if return_text: delimiter = "\n" return delimiter.join( @@ -92,7 +92,7 @@ class ExtractProcessor: def extract( cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: - if extract_setting.datasource_type == DatasourceType.FILE.value: + if extract_setting.datasource_type == DatasourceType.FILE: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: assert extract_setting.upload_file is not None, "upload_file is required" @@ -163,7 +163,7 @@ class ExtractProcessor: # txt extractor = TextExtractor(file_path, autodetect_encoding=True) return extractor.extract() - elif extract_setting.datasource_type == DatasourceType.NOTION.value: + elif extract_setting.datasource_type == DatasourceType.NOTION: assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( notion_workspace_id=extract_setting.notion_info.notion_workspace_id, @@ -174,7 +174,7 @@ class ExtractProcessor: credential_id=extract_setting.notion_info.credential_id, ) return extractor.extract() - elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: + elif extract_setting.datasource_type == DatasourceType.WEBSITE: assert extract_setting.website_info is not None, "website_info is required" if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index e1ba6ef243..c20ecd2b89 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -2,7 +2,7 @@ import json import time from typing import Any, cast -import requests +import httpx from extensions.ext_storage import storage @@ -104,18 +104,18 @@ class FirecrawlApp: def _prepare_headers(self) -> dict[str, Any]: return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} - def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response: + def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response: for attempt in range(retries): - response = requests.post(url, headers=headers, json=data) + response = httpx.post(url, headers=headers, json=data) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: return response return response - def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> requests.Response: + def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response: for attempt in range(retries): - response = requests.get(url, headers=headers) + response = httpx.get(url, headers=headers) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index bddf41af43..e87ab38349 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -3,7 +3,7 @@ import logging import operator from typing import Any, cast -import requests +import httpx from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor @@ -92,7 +92,7 @@ class NotionExtractor(BaseExtractor): if next_cursor: current_query["start_cursor"] = next_cursor - res = requests.post( + res = httpx.post( DATABASE_URL_TMPL.format(database_id=database_id), headers={ "Authorization": "Bearer " + self._notion_access_token, @@ -160,7 +160,7 @@ class NotionExtractor(BaseExtractor): while True: query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} try: - res = requests.request( + res = httpx.request( "GET", block_url, headers={ @@ -173,7 +173,7 @@ class NotionExtractor(BaseExtractor): if res.status_code != 200: raise ValueError(f"Error fetching Notion block data: {res.text}") data = res.json() - except requests.RequestException as e: + except httpx.HTTPError as e: raise ValueError("Error fetching Notion block data") from e if "results" not in data or not isinstance(data["results"], list): raise ValueError("Error fetching Notion block data") @@ -222,7 +222,7 @@ class NotionExtractor(BaseExtractor): while True: query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} - res = requests.request( + res = httpx.request( "GET", block_url, headers={ @@ -282,7 +282,7 @@ class NotionExtractor(BaseExtractor): while not done: query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} - res = requests.request( + res = httpx.request( "GET", block_url, headers={ @@ -354,7 +354,7 @@ class NotionExtractor(BaseExtractor): query_dict: dict[str, Any] = {} - res = requests.request( + res = httpx.request( "GET", retrieve_page_url, headers={ diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 6d596e07d8..7cf6c4d289 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -3,8 +3,8 @@ from collections.abc import Generator from typing import Union from urllib.parse import urljoin -import requests -from requests import Response +import httpx +from httpx import Response from core.rag.extractor.watercrawl.exceptions import ( WaterCrawlAuthenticationError, @@ -20,28 +20,45 @@ class BaseAPIClient: self.session = self.init_session() def init_session(self): - session = requests.Session() - session.headers.update({"X-API-Key": self.api_key}) - session.headers.update({"Content-Type": "application/json"}) - session.headers.update({"Accept": "application/json"}) - session.headers.update({"User-Agent": "WaterCrawl-Plugin"}) - session.headers.update({"Accept-Language": "en-US"}) - return session + headers = { + "X-API-Key": self.api_key, + "Content-Type": "application/json", + "Accept": "application/json", + "User-Agent": "WaterCrawl-Plugin", + "Accept-Language": "en-US", + } + return httpx.Client(headers=headers, timeout=None) + + def _request( + self, + method: str, + endpoint: str, + query_params: dict | None = None, + data: dict | None = None, + **kwargs, + ) -> Response: + stream = kwargs.pop("stream", False) + url = urljoin(self.base_url, endpoint) + if stream: + request = self.session.build_request(method, url, params=query_params, json=data) + return self.session.send(request, stream=True, **kwargs) + + return self.session.request(method, url, params=query_params, json=data, **kwargs) def _get(self, endpoint: str, query_params: dict | None = None, **kwargs): - return self.session.get(urljoin(self.base_url, endpoint), params=query_params, **kwargs) + return self._request("GET", endpoint, query_params=query_params, **kwargs) def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): - return self.session.post(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs) + return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs) def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): - return self.session.put(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs) + return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs) def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs): - return self.session.delete(urljoin(self.base_url, endpoint), params=query_params, **kwargs) + return self._request("DELETE", endpoint, query_params=query_params, **kwargs) def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): - return self.session.patch(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs) + return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs) class WaterCrawlAPIClient(BaseAPIClient): @@ -49,14 +66,17 @@ class WaterCrawlAPIClient(BaseAPIClient): super().__init__(api_key, base_url) def process_eventstream(self, response: Response, download: bool = False) -> Generator: - for line in response.iter_lines(): - line = line.decode("utf-8") - if line.startswith("data:"): - line = line[5:].strip() - data = json.loads(line) - if data["type"] == "result" and download: - data["data"] = self.download_result(data["data"]) - yield data + try: + for raw_line in response.iter_lines(): + line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line + if line.startswith("data:"): + line = line[5:].strip() + data = json.loads(line) + if data["type"] == "result" and download: + data["data"] = self.download_result(data["data"]) + yield data + finally: + response.close() def process_response(self, response: Response) -> dict | bytes | list | None | Generator: if response.status_code == 401: @@ -170,7 +190,10 @@ class WaterCrawlAPIClient(BaseAPIClient): return event_data["data"] def download_result(self, result_object: dict): - response = requests.get(result_object["result"]) - response.raise_for_status() - result_object["result"] = response.json() + response = httpx.get(result_object["result"], timeout=None) + try: + response.raise_for_status() + result_object["result"] = response.json() + finally: + response.close() return result_object diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index f25f92cf81..1a9704688a 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -9,7 +9,7 @@ import uuid from urllib.parse import urlparse from xml.etree import ElementTree -import requests +import httpx from docx import Document as DocxDocument from configs import dify_config @@ -43,15 +43,19 @@ class WordExtractor(BaseExtractor): # If the file is a web path, download it to a temporary file, and use that if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): - r = requests.get(self.file_path) + response = httpx.get(self.file_path, timeout=None) - if r.status_code != 200: - raise ValueError(f"Check the url of your file; returned status code {r.status_code}") + if response.status_code != 200: + response.close() + raise ValueError(f"Check the url of your file; returned status code {response.status_code}") self.web_path = self.file_path # TODO: use a better way to handle the file self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115 - self.temp_file.write(r.content) + try: + self.temp_file.write(response.content) + finally: + response.close() self.file_path = self.temp_file.name elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 755aa88d08..4fcffbcc77 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -38,11 +38,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if process_rule.get("mode") == "automatic": automatic_rule = DatasetProcessRule.AUTOMATIC_RULES - rules = Rule(**automatic_rule) + rules = Rule.model_validate(automatic_rule) else: if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) # Split the text documents into nodes. if not rules.segmentation: raise ValueError("No segmentation found in rules.") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index e0ccd8b567..7bdde286f5 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -40,7 +40,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) all_documents: list[Document] = [] if rules.parent_mode == ParentMode.PARAGRAPH: # Split the text documents into nodes. @@ -110,7 +110,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_documents = document.children if child_documents: formatted_child_documents = [ - Document(**child_document.model_dump()) for child_document in child_documents + Document.model_validate(child_document.model_dump()) for child_document in child_documents ] vector.create(formatted_child_documents) @@ -224,7 +224,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return child_nodes def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): - parent_childs = ParentChildStructureChunk(**chunks) + parent_childs = ParentChildStructureChunk.model_validate(chunks) documents = [] for parent_child in parent_childs.parent_child_chunks: metadata = { @@ -274,7 +274,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): vector.create(all_child_documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: - parent_childs = ParentChildStructureChunk(**chunks) + parent_childs = ParentChildStructureChunk.model_validate(chunks) preview = [] for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 2054031643..9c8f70dba8 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -47,7 +47,7 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) splitter = self._get_splitter( processing_rule_mode=process_rule.get("mode"), max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0, @@ -168,7 +168,7 @@ class QAIndexProcessor(BaseIndexProcessor): return docs def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): - qa_chunks = QAStructureChunk(**chunks) + qa_chunks = QAStructureChunk.model_validate(chunks) documents = [] for qa_chunk in qa_chunks.qa_chunks: metadata = { @@ -191,7 +191,7 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError("Indexing technique must be high quality.") def format_preview(self, chunks: Any) -> Mapping[str, Any]: - qa_chunks = QAStructureChunk(**chunks) + qa_chunks = QAStructureChunk.model_validate(chunks) preview = [] for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) diff --git a/api/core/rag/rerank/rerank_factory.py b/api/core/rag/rerank/rerank_factory.py index 1a3cf85736..524e83824c 100644 --- a/api/core/rag/rerank/rerank_factory.py +++ b/api/core/rag/rerank/rerank_factory.py @@ -8,9 +8,9 @@ class RerankRunnerFactory: @staticmethod def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner: match runner_type: - case RerankMode.RERANKING_MODEL.value: + case RerankMode.RERANKING_MODEL: return RerankModelRunner(*args, **kwargs) - case RerankMode.WEIGHTED_SCORE.value: + case RerankMode.WEIGHTED_SCORE: return WeightRerankRunner(*args, **kwargs) case _: raise ValueError(f"Unknown runner type: {runner_type}") diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index b08f80da49..0a702d2902 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -61,7 +61,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model: dict[str, Any] = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 4, @@ -692,7 +692,7 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 2, diff --git a/api/core/rag/retrieval/retrieval_methods.py b/api/core/rag/retrieval/retrieval_methods.py index c7c6e60c8d..5f0f2a9d33 100644 --- a/api/core/rag/retrieval/retrieval_methods.py +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -9,8 +9,8 @@ class RetrievalMethod(Enum): @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} + return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH} @staticmethod def is_support_fulltext_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} + return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH} diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 8356861242..801d2a2a52 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from typing import Any from core.model_manager import ModelInstance @@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator - self._separators = separators or ["\n\n", "\n", " ", ""] + self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""] def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" @@ -90,16 +91,19 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) # Now that we have the separator, split the text if separator: if separator == " ": - splits = text.split() + splits = re.split(r" +", text) else: splits = text.split(separator) splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] else: splits = list(text) - splits = [s for s in splits if (s not in {"", "\n"})] + if separator == "\n": + splits = [s for s in splits if s != ""] + else: + splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits - _separator = "" if self._keep_separator else separator + _separator = separator if self._keep_separator else "" s_lens = self._length_function(splits) if separator != "": for s, s_len in zip(splits, s_lens): diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 3de0014c61..09bc817c01 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -1,7 +1,6 @@ from typing import Any -from openai import BaseModel -from pydantic import Field +from pydantic import BaseModel, Field from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 45fd16d684..2e94907f30 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -90,7 +90,7 @@ class BuiltinToolProviderController(ToolProviderController): tools.append( assistant_tool_class( provider=provider, - entity=ToolEntity(**tool), + entity=ToolEntity.model_validate(tool), runtime=ToolRuntime(tenant_id=""), ) ) @@ -111,7 +111,7 @@ class BuiltinToolProviderController(ToolProviderController): :return: the credentials schema """ - return self.get_credentials_schema_by_type(CredentialType.API_KEY.value) + return self.get_credentials_schema_by_type(CredentialType.API_KEY) def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: """ @@ -122,7 +122,7 @@ class BuiltinToolProviderController(ToolProviderController): """ if credential_type == CredentialType.OAUTH2.value: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] - if credential_type == CredentialType.API_KEY.value: + if credential_type == CredentialType.API_KEY: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] raise ValueError(f"Invalid credential type: {credential_type}") @@ -134,15 +134,15 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] - def get_supported_credential_types(self) -> list[str]: + def get_supported_credential_types(self) -> list[CredentialType]: """ returns the credential support type of the provider """ types = [] if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0: - types.append(CredentialType.API_KEY.value) + types.append(CredentialType.API_KEY) if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0: - types.append(CredentialType.OAUTH2.value) + types.append(CredentialType.OAUTH2) return types def get_tools(self) -> list[BuiltinTool]: diff --git a/api/core/tools/builtin_tool/providers/code/_assets/icon.svg b/api/core/tools/builtin_tool/providers/code/_assets/icon.svg index b986ed9426..154726a081 100644 --- a/api/core/tools/builtin_tool/providers/code/_assets/icon.svg +++ b/api/core/tools/builtin_tool/providers/code/_assets/icon.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 34d0f5c622..f18f638f2d 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -290,6 +290,7 @@ class ApiTool(Tool): method_lc ]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926 url, + max_retries=0, params=params, headers=headers, cookies=cookies, diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 00c4ab9dd7..de6bf01ae9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -61,7 +61,7 @@ class ToolProviderApiEntity(BaseModel): for tool in tools: if tool.get("parameters"): for parameter in tool.get("parameters"): - if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: + if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES: parameter["type"] = "files" if parameter.get("input_schema") is None: parameter.pop("input_schema", None) @@ -110,7 +110,9 @@ class ToolProviderCredentialApiEntity(BaseModel): class ToolProviderCredentialInfoApiEntity(BaseModel): - supported_credential_types: list[str] = Field(description="The supported credential types of the provider") + supported_credential_types: list[CredentialType] = Field( + description="The supported credential types of the provider" + ) is_oauth_custom_client_enabled: bool = Field( default=False, description="Whether the OAuth custom client is enabled for the provider" ) diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 2c6d9c1964..21d310bbb9 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class I18nObject(BaseModel): @@ -11,11 +11,12 @@ class I18nObject(BaseModel): pt_BR: str | None = Field(default=None) ja_JP: str | None = Field(default=None) - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _populate_missing_locales(self): self.zh_Hans = self.zh_Hans or self.en_US self.pt_BR = self.pt_BR or self.en_US self.ja_JP = self.ja_JP or self.en_US + return self def to_dict(self): return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index a59b54216f..62e3aa8b5d 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -113,7 +113,7 @@ class ApiProviderAuthType(StrEnum): # normalize & tiny alias for backward compatibility v = (value or "").strip().lower() if v == "api_key": - v = cls.API_KEY_HEADER.value + v = cls.API_KEY_HEADER for mode in cls: if mode.value == v: diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 5b04f0edbe..f269b8db9b 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -54,7 +54,7 @@ class MCPToolProviderController(ToolProviderController): """ tools = [] tools_data = json.loads(db_provider.tools) - remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data] + remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data] user = db_provider.load_user() tools = [ ToolEntity( diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 9e5f5a7c23..af68971ca7 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1008,7 +1008,7 @@ class ToolManager: config = tool_configurations.get(parameter.name, {}) if not (config and isinstance(config, dict) and config.get("value") is not None): continue - tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {})) + tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": variable = variable_pool.get(tool_input.value) if variable is None: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 75c0c6738e..b5bc4d3c00 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -18,7 +18,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: dict[str, Any] = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 2, @@ -126,7 +126,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): data_source_type=document.data_source_type, segment_id=segment.id, retriever_from=self.retriever_from, - score=document_score_list.get(segment.index_node_id, None), + score=document_score_list.get(segment.index_node_id), doc_metadata=document.doc_metadata, ) diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index ac2967d0c1..dd0b4bedcf 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -18,6 +18,10 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): retriever_from: str model_config = ConfigDict(arbitrary_types_allowed=True) + def run(self, query: str) -> str: + """Use the tool.""" + return self._run(query) + @abstractmethod def _run(self, query: str) -> str: """Use the tool. diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 0e2237befd..1eae582f67 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -17,7 +17,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model: dict[str, Any] = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "reranking_mode": "reranking_model", diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index a62d419243..fca6e6f1c7 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -124,7 +124,7 @@ class DatasetRetrieverTool(Tool): yield self.create_text_message(text="please input query") else: # invoke dataset retriever tool - result = self.retrieval_tool._run(query=query) + result = self.retrieval_tool.run(query=query) yield self.create_text_message(text=result) def validate_credentials( diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 2e306db6c7..c7ac3387e5 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -2,9 +2,10 @@ import re from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError +from typing import Any +import httpx from flask import request -from requests import get from yaml import YAMLError, safe_load from core.tools.entities.common_entities import I18nObject @@ -127,34 +128,34 @@ class ApiBasedToolSchemaParser: if "allOf" in prop_dict: del prop_dict["allOf"] - # parse body parameters - if "schema" in interface["operation"]["requestBody"]["content"][content_type]: - body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] - required = body_schema.get("required", []) - properties = body_schema.get("properties", {}) - for name, property in properties.items(): - tool = ToolParameter( - name=name, - label=I18nObject(en_US=name, zh_Hans=name), - human_description=I18nObject( - en_US=property.get("description", ""), zh_Hans=property.get("description", "") - ), - type=ToolParameter.ToolParameterType.STRING, - required=name in required, - form=ToolParameter.ToolParameterForm.LLM, - llm_description=property.get("description", ""), - default=property.get("default", None), - placeholder=I18nObject( - en_US=property.get("description", ""), zh_Hans=property.get("description", "") - ), - ) + # parse body parameters + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) + for name, property in properties.items(): + tool = ToolParameter( + name=name, + label=I18nObject(en_US=name, zh_Hans=name), + human_description=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + type=ToolParameter.ToolParameterType.STRING, + required=name in required, + form=ToolParameter.ToolParameterForm.LLM, + llm_description=property.get("description", ""), + default=property.get("default", None), + placeholder=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + ) - # check if there is a type - typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) - if typ: - tool.type = typ + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) + if typ: + tool.type = typ - parameters.append(tool) + parameters.append(tool) # check if parameters is duplicated parameters_count = {} @@ -241,7 +242,9 @@ class ApiBasedToolSchemaParser: return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) @staticmethod - def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None): + def parse_swagger_to_openapi( + swagger: dict, extra_info: dict | None = None, warning: dict | None = None + ) -> dict[str, Any]: warning = warning or {} """ parse swagger to openapi @@ -257,7 +260,7 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - openapi = { + converted_openapi: dict[str, Any] = { "openapi": "3.0.0", "info": { "title": info.get("title", "Swagger"), @@ -275,7 +278,7 @@ class ApiBasedToolSchemaParser: # convert paths for path, path_item in swagger["paths"].items(): - openapi["paths"][path] = {} + converted_openapi["paths"][path] = {} for method, operation in path_item.items(): if "operationId" not in operation: raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") @@ -286,7 +289,7 @@ class ApiBasedToolSchemaParser: if warning is not None: warning["missing_summary"] = f"No summary or description found in operation {method} {path}." - openapi["paths"][path][method] = { + converted_openapi["paths"][path][method] = { "operationId": operation["operationId"], "summary": operation.get("summary", ""), "description": operation.get("description", ""), @@ -295,13 +298,14 @@ class ApiBasedToolSchemaParser: } if "requestBody" in operation: - openapi["paths"][path][method]["requestBody"] = operation["requestBody"] + converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # convert definitions - for name, definition in swagger["definitions"].items(): - openapi["components"]["schemas"][name] = definition + if "definitions" in swagger: + for name, definition in swagger["definitions"].items(): + converted_openapi["components"]["schemas"][name] = definition - return openapi + return converted_openapi @staticmethod def parse_openai_plugin_json_to_tool_bundle( @@ -330,15 +334,20 @@ class ApiBasedToolSchemaParser: raise ToolNotSupportedError("Only openapi is supported now.") # get openapi yaml - response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) - - if response.status_code != 200: - raise ToolProviderNotFoundError("cannot get openapi yaml from url.") - - return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( - response.text, extra_info=extra_info, warning=warning + response = httpx.get( + api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5 ) + try: + if response.status_code != 200: + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") + + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + finally: + response.close() + @staticmethod def auto_parse_to_tool_bundle( content: str, extra_info: dict | None = None, warning: dict | None = None @@ -384,7 +393,7 @@ class ApiBasedToolSchemaParser: openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( loaded_content, extra_info=extra_info, warning=warning ) - schema_type = ApiProviderSchemaType.OPENAPI.value + schema_type = ApiProviderSchemaType.OPENAPI return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e @@ -394,7 +403,7 @@ class ApiBasedToolSchemaParser: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning ) - schema_type = ApiProviderSchemaType.SWAGGER.value + schema_type = ApiProviderSchemaType.SWAGGER return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( converted_swagger, extra_info=extra_info, warning=warning ), schema_type @@ -406,7 +415,7 @@ class ApiBasedToolSchemaParser: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( json_dumps(loaded_content), extra_info=extra_info, warning=warning ) - return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value + return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md index bef19ba90b..72f5dbe1e2 100644 --- a/api/core/workflow/README.md +++ b/api/core/workflow/README.md @@ -60,8 +60,8 @@ Extensible middleware for cross-cutting concerns: ```python engine = GraphEngine(graph) -engine.add_layer(DebugLoggingLayer(level="INFO")) -engine.add_layer(ExecutionLimitsLayer(max_nodes=100)) +engine.layer(DebugLoggingLayer(level="INFO")) +engine.layer(ExecutionLimitsLayer(max_nodes=100)) ``` ### Event-Driven Architecture @@ -117,7 +117,7 @@ The codebase enforces strict layering via import-linter: 1. Create class inheriting from `Layer` base 1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` -1. Add to engine via `engine.add_layer()` +1. Add to engine via `engine.layer()` ### Debugging Workflow Execution diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 8ceabde7e6..2dc00fd70b 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -184,11 +184,22 @@ class VariablePool(BaseModel): """Extract the actual value from an ObjectSegment.""" return obj.value if isinstance(obj, ObjectSegment) else obj - def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str): - """Get a nested attribute from a dictionary-like object.""" - if not isinstance(obj, dict): + def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None: + """ + Get a nested attribute from a dictionary-like object. + + Args: + obj: The dictionary-like object to search. + attr: The key to look up. + + Returns: + Segment | None: + The corresponding Segment built from the attribute value if the key exists, + otherwise None. + """ + if not isinstance(obj, dict) or attr not in obj: return None - return obj.get(attr) + return variable_factory.build_segment(obj.get(attr)) def remove(self, selector: Sequence[str], /): """ diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 056e17bf5d..c841459170 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -105,10 +105,10 @@ class RedisChannel: command_type = CommandType(command_type_value) if command_type == CommandType.ABORT: - return AbortCommand(**data) + return AbortCommand.model_validate(data) else: # For other command types, use base class - return GraphEngineCommand(**data) + return GraphEngineCommand.model_validate(data) except (ValueError, TypeError): return None diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/core/workflow/graph_engine/domain/graph_execution.py index 5951af1087..b273ee9969 100644 --- a/api/core/workflow/graph_engine/domain/graph_execution.py +++ b/api/core/workflow/graph_engine/domain/graph_execution.py @@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel): completed: bool = Field(default=False) aborted: bool = Field(default=False) error: GraphExecutionErrorState | None = Field(default=None) - node_executions: list[NodeExecutionState] = Field(default_factory=list) + exceptions_count: int = Field(default=0) + node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: @@ -103,7 +104,8 @@ class GraphExecution: completed: bool = False aborted: bool = False error: Exception | None = None - node_executions: dict[str, NodeExecution] = field(default_factory=dict) + node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) + exceptions_count: int = 0 def start(self) -> None: """Mark the graph execution as started.""" @@ -172,6 +174,7 @@ class GraphExecution: completed=self.completed, aborted=self.aborted, error=_serialize_error(self.error), + exceptions_count=self.exceptions_count, node_executions=node_states, ) @@ -195,6 +198,7 @@ class GraphExecution: self.completed = state.completed self.aborted = state.aborted self.error = _deserialize_error(state.error) + self.exceptions_count = state.exceptions_count self.node_executions = { item.node_id: NodeExecution( node_id=item.node_id, @@ -205,3 +209,7 @@ class GraphExecution: ) for item in state.node_executions } + + def record_node_failure(self) -> None: + """Increment the count of node failures encountered during execution.""" + self.exceptions_count += 1 diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 244f4a4d86..7247b17967 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -3,11 +3,12 @@ Event handler implementations for different event types. """ import logging +from collections.abc import Mapping from functools import singledispatchmethod from typing import TYPE_CHECKING, final from core.workflow.entities import GraphRuntimeState -from core.workflow.enums import NodeExecutionType +from core.workflow.enums import ErrorStrategy, NodeExecutionType from core.workflow.graph import Graph from core.workflow.graph_events import ( GraphNodeEventBase, @@ -122,13 +123,15 @@ class EventHandler: """ # Track execution in domain model node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) + is_initial_attempt = node_execution.retry_count == 0 node_execution.mark_started(event.id) # Track in response coordinator for stream ordering self._response_coordinator.track_node_execution(event.node_id, event.id) - # Collect the event - self._event_collector.collect(event) + # Collect the event only for the first attempt; retries remain silent + if is_initial_attempt: + self._event_collector.collect(event) @_dispatch.register def _(self, event: NodeRunStreamChunkEvent) -> None: @@ -161,7 +164,7 @@ class EventHandler: node_execution.mark_taken() # Store outputs in variable pool - self._store_node_outputs(event) + self._store_node_outputs(event.node_id, event.node_run_result.outputs) # Forward to response coordinator and emit streaming events streaming_events = self._response_coordinator.intercept_event(event) @@ -191,7 +194,7 @@ class EventHandler: # Handle response node outputs if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event) + self._update_response_outputs(event.node_run_result.outputs) # Collect the event self._event_collector.collect(event) @@ -207,6 +210,7 @@ class EventHandler: # Update domain model node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_failed(event.error) + self._graph_execution.record_node_failure() result = self._error_handler.handle_node_failure(event) @@ -227,10 +231,40 @@ class EventHandler: Args: event: The node exception event """ - # Node continues via fail-branch, so it's technically "succeeded" + # Node continues via fail-branch/default-value, treat as completion node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() + # Persist outputs produced by the exception strategy (e.g. default values) + self._store_node_outputs(event.node_id, event.node_run_result.outputs) + + node = self._graph.nodes[event.node_id] + + if node.error_strategy == ErrorStrategy.DEFAULT_VALUE: + ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) + elif node.error_strategy == ErrorStrategy.FAIL_BRANCH: + ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( + event.node_id, event.node_run_result.edge_source_handle + ) + else: + raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}") + + for edge_event in edge_streaming_events: + self._event_collector.collect(edge_event) + + for node_id in ready_nodes: + self._state_manager.enqueue_node(node_id) + self._state_manager.start_execution(node_id) + + # Update response outputs if applicable + if node.execution_type == NodeExecutionType.RESPONSE: + self._update_response_outputs(event.node_run_result.outputs) + + self._state_manager.finish_execution(event.node_id) + + # Collect the exception event for observers + self._event_collector.collect(event) + @_dispatch.register def _(self, event: NodeRunRetryEvent) -> None: """ @@ -242,21 +276,31 @@ class EventHandler: node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.increment_retry() - def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None: + # Finish the previous attempt before re-queuing the node + self._state_manager.finish_execution(event.node_id) + + # Emit retry event for observers + self._event_collector.collect(event) + + # Re-queue node for execution + self._state_manager.enqueue_node(event.node_id) + self._state_manager.start_execution(event.node_id) + + def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None: """ Store node outputs in the variable pool. Args: event: The node succeeded event containing outputs """ - for variable_name, variable_value in event.node_run_result.outputs.items(): - self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value) + for variable_name, variable_value in outputs.items(): + self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value) - def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None: + def _update_response_outputs(self, outputs: Mapping[str, object]) -> None: """Update response outputs for response nodes.""" # TODO: Design a mechanism for nodes to notify the engine about how to update outputs # in runtime state, rather than allowing nodes to directly access runtime state. - for key, value in event.node_run_result.outputs.items(): + for key, value in outputs.items(): if key == "answer": existing = self._graph_runtime_state.get_output("answer", "") if existing: diff --git a/api/core/workflow/graph_engine/event_management/event_manager.py b/api/core/workflow/graph_engine/event_management/event_manager.py index 6f37193070..751a2a4352 100644 --- a/api/core/workflow/graph_engine/event_management/event_manager.py +++ b/api/core/workflow/graph_engine/event_management/event_manager.py @@ -5,6 +5,7 @@ Unified event manager for collecting and emitting events. import threading import time from collections.abc import Generator +from contextlib import contextmanager from typing import final from core.workflow.graph_events import GraphEngineEvent @@ -51,43 +52,23 @@ class ReadWriteLock: """Release a write lock.""" self._read_ready.release() - def read_lock(self) -> "ReadLockContext": + @contextmanager + def read_lock(self): """Return a context manager for read locking.""" - return ReadLockContext(self) + self.acquire_read() + try: + yield + finally: + self.release_read() - def write_lock(self) -> "WriteLockContext": + @contextmanager + def write_lock(self): """Return a context manager for write locking.""" - return WriteLockContext(self) - - -@final -class ReadLockContext: - """Context manager for read locks.""" - - def __init__(self, lock: ReadWriteLock) -> None: - self._lock = lock - - def __enter__(self) -> "ReadLockContext": - self._lock.acquire_read() - return self - - def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: - self._lock.release_read() - - -@final -class WriteLockContext: - """Context manager for write locks.""" - - def __init__(self, lock: ReadWriteLock) -> None: - self._lock = lock - - def __enter__(self) -> "WriteLockContext": - self._lock.acquire_write() - return self - - def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: - self._lock.release_write() + self.acquire_write() + try: + yield + finally: + self.release_write() @final diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 164ae41cca..a21fb7c022 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -23,6 +23,7 @@ from core.workflow.graph_events import ( GraphNodeEventBase, GraphRunAbortedEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) @@ -260,12 +261,23 @@ class GraphEngine: if self._graph_execution.error: raise self._graph_execution.error else: - yield GraphRunSucceededEvent( - outputs=self._graph_runtime_state.outputs, - ) + outputs = self._graph_runtime_state.outputs + exceptions_count = self._graph_execution.exceptions_count + if exceptions_count > 0: + yield GraphRunPartialSucceededEvent( + exceptions_count=exceptions_count, + outputs=outputs, + ) + else: + yield GraphRunSucceededEvent( + outputs=outputs, + ) except Exception as e: - yield GraphRunFailedEvent(error=str(e)) + yield GraphRunFailedEvent( + error=str(e), + exceptions_count=self._graph_execution.exceptions_count, + ) raise finally: diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/core/workflow/graph_engine/layers/README.md index 8ee35baec0..17845ee1f0 100644 --- a/api/core/workflow/graph_engine/layers/README.md +++ b/api/core/workflow/graph_engine/layers/README.md @@ -30,7 +30,7 @@ debug_layer = DebugLoggingLayer( ) engine = GraphEngine(graph) -engine.add_layer(debug_layer) +engine.layer(debug_layer) engine.run() ``` diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py index f24c3fe33c..034ebcf54f 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/core/workflow/graph_engine/layers/debug_logging.py @@ -15,6 +15,7 @@ from core.workflow.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunExceptionEvent, @@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer): if self.include_outputs and event.outputs: self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) + elif isinstance(event, GraphRunPartialSucceededEvent): + self.logger.warning("⚠️ Graph run partially succeeded") + if event.exceptions_count > 0: + self.logger.warning(" Total exceptions: %s", event.exceptions_count) + if self.include_outputs and event.outputs: + self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) + elif isinstance(event, GraphRunFailedEvent): self.logger.error("❌ Graph run failed: %s", event.error) if event.exceptions_count > 0: @@ -138,6 +146,12 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs)) # Node-level events + # Retry before Started because Retry subclasses Started; + elif isinstance(event, NodeRunRetryEvent): + self.retry_count += 1 + self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index) + self.logger.warning(" Previous error: %s", event.error) + elif isinstance(event, NodeRunStartedEvent): self.node_count += 1 self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type) @@ -167,11 +181,6 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.warning("⚠️ Node exception handled: %s", event.node_id) self.logger.warning(" Error: %s", event.error) - elif isinstance(event, NodeRunRetryEvent): - self.retry_count += 1 - self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index) - self.logger.warning(" Previous error: %s", event.error) - elif isinstance(event, NodeRunStreamChunkEvent): # Log stream chunks at debug level to avoid spam final_indicator = " (FINAL)" if event.is_final else "" diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 985992f3f1..3db40c545e 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -212,10 +212,11 @@ class ResponseStreamCoordinator: edge = self._graph.edges[edge_id] source_node = self._graph.nodes[edge.tail] - # Check if node is a branch/container (original behavior) + # Check if node is a branch, container, or response node if source_node.execution_type in { NodeExecutionType.BRANCH, NodeExecutionType.CONTAINER, + NodeExecutionType.RESPONSE, } or source_node.blocks_variable_output(variable_selectors): blocking_edges.append(edge_id) diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index c1aeb9fe27..93dfefb679 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -20,6 +20,7 @@ class ModelInvokeCompletedEvent(NodeEventBase): usage: LLMUsage finish_reason: str | None = None reasoning_content: str | None = None + structured_output: dict | None = None class RunRetryEvent(NodeEventBase): diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index ec05805879..4a24b18465 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -252,7 +252,10 @@ class AgentNode(Node): if all(isinstance(v, dict) for _, v in parameters.items()): params = {} for key, param in parameters.items(): - if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value: + if param.get("auto", ParamsAutoGenerated.OPEN) in ( + ParamsAutoGenerated.CLOSE, + 0, + ): value_param = param.get("value", {}) params[key] = value_param.get("value", "") if value_param is not None else None else: @@ -266,7 +269,7 @@ class AgentNode(Node): value = cast(list[dict[str, Any]], value) tool_value = [] for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value)) + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) setting_params = tool.get("settings", {}) parameters = tool.get("parameters", {}) manual_input_params = [key for key, value in parameters.items() if value is not None] @@ -288,7 +291,7 @@ class AgentNode(Node): # But for backward compatibility with historical data # this version field judgment is still preserved here. runtime_variable_pool: VariablePool | None = None - if node_data.version != "1" or node_data.tool_node_version != "1": + if node_data.version != "1" or node_data.tool_node_version is not None: runtime_variable_pool = variable_pool tool_runtime = ToolManager.get_agent_tool_runtime( self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool @@ -417,7 +420,7 @@ class AgentNode(Node): def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: # get conversation id conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID.value] + ["sys", SystemVariableKey.CONVERSATION_ID] ) if not isinstance(conversation_id_variable, StringSegment): return None @@ -476,7 +479,7 @@ class AgentNode(Node): if meta_version and Version(meta_version) > Version("0.0.1"): return tools else: - return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value] + return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] def _transform_message( self, diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 937f4c944f..e392cb5f5c 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -75,11 +75,11 @@ class DatasourceNode(Node): node_data = self._node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) + datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) if not datasource_type_segement: raise DatasourceNodeError("Datasource type is not set") datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None - datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) + datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) if not datasource_info_segement: raise DatasourceNodeError("Datasource info is not set") datasource_info_value = datasource_info_segement.value @@ -267,7 +267,7 @@ class DatasourceNode(Node): return result def _fetch_files(self, variable_pool: VariablePool) -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + variable = variable_pool.get(["sys", SystemVariableKey.FILES]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 2bdfe4efce..7ec74084d0 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -16,7 +16,7 @@ class EndNode(Node): _node_data: EndNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = EndNodeData(**data) + self._node_data = EndNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index c47ffb5ab0..d3d3571b44 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -87,7 +87,7 @@ class Executor: node_data.authorization.config.api_key ).text - self.url: str = node_data.url + self.url = node_data.url self.method = node_data.method self.auth = node_data.authorization self.timeout = timeout @@ -349,11 +349,10 @@ class Executor: "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), "ssl_verify": self.ssl_verify, "follow_redirects": True, - "max_retries": self.max_retries, } # request_args = {k: v for k, v in request_args.items() if v is not None} try: - response: httpx.Response = _METHOD_MAP[method_lc](**request_args) + response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries) except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: raise HttpRequestNodeError(str(e)) from e # FIXME: fix type ignore, this maybe httpx type issue diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 826820a8e3..55dec3fb08 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -165,6 +165,8 @@ class HttpRequestNode(Node): body_type = typed_node_data.body.type data = typed_node_data.body.data match body_type: + case "none": + pass case "binary": if len(data) != 1: raise RequestBodyError("invalid body data, should have only one item") @@ -232,7 +234,7 @@ class HttpRequestNode(Node): mapping = { "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE.value, + "transfer_method": FileTransferMethod.TOOL_FILE, } file = file_factory.build_from_mapping( mapping=mapping, diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 075f6f8444..7e3b6ecc1a 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -83,7 +83,7 @@ class IfElseNode(Node): else: # TODO: Update database then remove this # Fallback to old structure if cases are not defined - input_conditions, group_result, final_result = _should_not_use_old_function( # ty: ignore [deprecated] + input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, conditions=self._node_data.conditions or [], diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 5340a5b6ce..c089a68bd4 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,13 +1,17 @@ +import contextvars import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, NewType, cast +from flask import Flask, current_app from typing_extensions import TypeIs from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment +from core.variables.variables import VariableUnion +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.entities import VariablePool from core.workflow.enums import ( ErrorStrategy, @@ -19,6 +23,7 @@ from core.workflow.enums import ( from core.workflow.graph_events import ( GraphNodeEventBase, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunSucceededEvent, ) from core.workflow.node_events import ( @@ -34,6 +39,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from libs.datetime_utils import naive_utc_now +from libs.flask_utils import preserve_flask_contexts from .exc import ( InvalidIteratorValueError, @@ -89,7 +95,7 @@ class IterationNode(Node): "config": { "is_parallel": False, "parallel_nums": 10, - "error_handle_mode": ErrorHandleMode.TERMINATED.value, + "error_handle_mode": ErrorHandleMode.TERMINATED, }, } @@ -213,6 +219,13 @@ class IterationNode(Node): graph_engine=graph_engine, ) + # Sync conversation variables after each iteration completes + self._sync_conversation_variables_from_snapshot( + self._extract_conversation_variable_snapshot( + variable_pool=graph_engine.graph_runtime_state.variable_pool + ) + ) + # Update the total tokens from this iteration self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() @@ -231,13 +244,18 @@ class IterationNode(Node): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks - future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {} + future_to_index: dict[ + Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]], + int, + ] = {} for index, item in enumerate(iterator_list_value): yield IterationNextEvent(index=index) future = executor.submit( self._execute_single_iteration_parallel, index=index, item=item, + flask_app=current_app._get_current_object(), # type: ignore + context_vars=contextvars.copy_context(), ) future_to_index[future] = index @@ -246,7 +264,7 @@ class IterationNode(Node): index = future_to_index[future] try: result = future.result() - iter_start_at, events, output_value, tokens_used = result + iter_start_at, events, output_value, tokens_used, conversation_snapshot = result # Update outputs at the correct index outputs[index] = output_value @@ -258,6 +276,9 @@ class IterationNode(Node): self.graph_runtime_state.total_tokens += tokens_used iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + # Sync conversation variables after iteration completion + self._sync_conversation_variables_from_snapshot(conversation_snapshot) + except Exception as e: # Handle errors based on error_handle_mode match self._node_data.error_handle_mode: @@ -280,26 +301,38 @@ class IterationNode(Node): self, index: int, item: object, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]: + flask_app: Flask, + context_vars: contextvars.Context, + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]: """Execute a single iteration in parallel mode and return results.""" - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] + with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): + iter_start_at = datetime.now(UTC).replace(tzinfo=None) + events: list[GraphNodeEventBase] = [] + outputs_temp: list[object] = [] - graph_engine = self._create_graph_engine(index, item) + graph_engine = self._create_graph_engine(index, item) - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) + # Collect events instead of yielding them directly + for event in self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs_temp, + graph_engine=graph_engine, + ): + events.append(event) - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None + # Get the output value from the temporary outputs list + output_value = outputs_temp[0] if outputs_temp else None + conversation_snapshot = self._extract_conversation_variable_snapshot( + variable_pool=graph_engine.graph_runtime_state.variable_pool + ) - return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens + return ( + iter_start_at, + events, + output_value, + graph_engine.graph_runtime_state.total_tokens, + conversation_snapshot, + ) def _handle_iteration_success( self, @@ -309,10 +342,13 @@ class IterationNode(Node): iterator_list_value: Sequence[object], iter_run_map: dict[str, float], ) -> Generator[NodeEventBase, None, None]: + # Flatten the list of lists if all outputs are lists + flattened_outputs = self._flatten_outputs_if_needed(outputs) + yield IterationSucceededEvent( start_at=started_at, inputs=inputs, - outputs={"output": outputs}, + outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, @@ -324,13 +360,39 @@ class IterationNode(Node): yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, + outputs={"output": flattened_outputs}, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, }, ) ) + def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]: + """ + Flatten the outputs list if all elements are lists. + This maintains backward compatibility with version 1.8.1 behavior. + """ + if not outputs: + return outputs + + # Check if all non-None outputs are lists + non_none_outputs = [output for output in outputs if output is not None] + if not non_none_outputs: + return outputs + + if all(isinstance(output, list) for output in non_none_outputs): + # Flatten the list of lists + flattened: list[Any] = [] + for output in outputs: + if isinstance(output, list): + flattened.extend(output) + elif output is not None: + # This shouldn't happen based on our check, but handle it gracefully + flattened.append(output) + return flattened + + return outputs + def _handle_iteration_failure( self, started_at: datetime, @@ -340,10 +402,13 @@ class IterationNode(Node): iter_run_map: dict[str, float], error: IterationNodeError, ) -> Generator[NodeEventBase, None, None]: + # Flatten the list of lists if all outputs are lists (even in failure case) + flattened_outputs = self._flatten_outputs_if_needed(outputs) + yield IterationFailedEvent( start_at=started_at, inputs=inputs, - outputs={"output": outputs}, + outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, @@ -372,43 +437,16 @@ class IterationNode(Node): variable_mapping: dict[str, Sequence[str]] = { f"{node_id}.input_selector": typed_node_data.iterator_selector, } + iteration_node_ids = set() - # init graph - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.graph import Graph - from core.workflow.nodes.node_factory import DifyNodeFactory - - # Create minimal GraphInitParams for static analysis - graph_init_params = GraphInitParams( - tenant_id="", - app_id="", - workflow_id="", - graph_config=graph_config, - user_id="", - user_from="", - invoke_from="", - call_depth=0, - ) - - # Create minimal GraphRuntimeState for static analysis - from core.workflow.entities import VariablePool - - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(), - start_at=0, - ) - - # Create node factory for static analysis - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - iteration_graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=typed_node_data.start_node_id, - ) - - if not iteration_graph: - raise IterationGraphNotFoundError("iteration graph not found") + # Find all nodes that belong to this loop + nodes = graph_config.get("nodes", []) + for node in nodes: + node_data = node.get("data", {}) + if node_data.get("iteration_id") == node_id: + in_iteration_node_id = node.get("id") + if in_iteration_node_id: + iteration_node_ids.add(in_iteration_node_id) # Get node configs from graph_config instead of non-existent node_id_config_mapping node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} @@ -444,12 +482,27 @@ class IterationNode(Node): variable_mapping.update(sub_node_variable_mapping) # remove variable out from iteration - variable_mapping = { - key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids - } + variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids} return variable_mapping + def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]: + conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) + return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} + + def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None: + parent_pool = self.graph_runtime_state.variable_pool + parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) + + current_keys = set(parent_conversations.keys()) + snapshot_keys = set(snapshot.keys()) + + for removed_key in current_keys - snapshot_keys: + parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key)) + + for name, variable in snapshot.items(): + parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable) + def _append_iteration_info_to_event( self, event: GraphNodeEventBase, @@ -485,7 +538,7 @@ class IterationNode(Node): if isinstance(event, GraphNodeEventBase): self._append_iteration_info_to_event(event=event, iter_run_index=current_index) yield event - elif isinstance(event, GraphRunSucceededEvent): + elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)): result = variable_pool.get(self._node_data.output_selector) if result is None: outputs.append(None) diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 80f39ccebc..90b7f4539b 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -18,7 +18,7 @@ class IterationStartNode(Node): _node_data: IterationStartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = IterationStartNodeData(**data) + self._node_data = IterationStartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 2a2e983a0c..c79373afd5 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -63,7 +63,7 @@ class RetrievalSetting(BaseModel): Retrieval Setting. """ - search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"] + search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"] top_k: int score_threshold: float | None = 0.5 score_threshold_enabled: bool = False diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 4b6bad1aa3..2751f24048 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,7 +2,7 @@ import datetime import logging import time from collections.abc import Mapping -from typing import Any, cast +from typing import Any from sqlalchemy import func, select @@ -27,7 +27,7 @@ from .exc import ( logger = logging.getLogger(__name__) default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 2, @@ -62,7 +62,7 @@ class KnowledgeIndexNode(Node): return self._node_data def _run(self) -> NodeRunResult: # type: ignore - node_data = cast(KnowledgeIndexNodeData, self._node_data) + node_data = self._node_data variable_pool = self.graph_runtime_state.variable_pool dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) if not dataset_id: @@ -77,7 +77,7 @@ class KnowledgeIndexNode(Node): raise KnowledgeIndexNodeError("Index chunk variable is required.") invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) if invoke_from: - is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value + is_preview = invoke_from.value == InvokeFrom.DEBUGGER else: is_preview = False chunks = variable.value @@ -136,6 +136,11 @@ class KnowledgeIndexNode(Node): document = db.session.query(Document).filter_by(id=document_id.value).first() if not document: raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.") + doc_id_value = document.id + ds_id_value = dataset.id + dataset_name_value = dataset.name + document_name_value = document.name + created_at_value = document.created_at # chunk nodes by chunk size indexing_start_at = time.perf_counter() index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() @@ -161,16 +166,16 @@ class KnowledgeIndexNode(Node): document.word_count = ( db.session.query(func.sum(DocumentSegment.word_count)) .where( - DocumentSegment.document_id == document.id, - DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == doc_id_value, + DocumentSegment.dataset_id == ds_id_value, ) .scalar() ) db.session.add(document) # update document segment status db.session.query(DocumentSegment).where( - DocumentSegment.document_id == document.id, - DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == doc_id_value, + DocumentSegment.dataset_id == ds_id_value, ).update( { DocumentSegment.status: "completed", @@ -182,13 +187,13 @@ class KnowledgeIndexNode(Node): db.session.commit() return { - "dataset_id": dataset.id, - "dataset_name": dataset.name, + "dataset_id": ds_id_value, + "dataset_name": dataset_name_value, "batch": batch.value, - "document_id": document.id, - "document_name": document.name, - "created_at": document.created_at.timestamp(), - "display_status": document.indexing_status, + "document_id": doc_id_value, + "document_name": document_name_value, + "created_at": created_at_value.timestamp(), + "display_status": "completed", } def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1afb2e05b9..7091b62463 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -72,7 +72,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 4, @@ -107,7 +107,7 @@ class KnowledgeRetrievalNode(Node): graph_runtime_state=graph_runtime_state, ) # LLM file outputs, used for MultiModal outputs. - self._file_outputs: list[File] = [] + self._file_outputs = [] if llm_file_saver is None: llm_file_saver = FileSaverImpl( diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 7a31d69221..180eb2ad90 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -41,7 +41,7 @@ class ListOperatorNode(Node): _node_data: ListOperatorNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = ListOperatorNodeData(**data) + self._node_data = ListOperatorNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy @@ -161,6 +161,8 @@ class ListOperatorNode(Node): elif isinstance(variable, ArrayFileSegment): if isinstance(condition.value, str): value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + elif isinstance(condition.value, bool): + raise ValueError(f"File filter expects a string value, got {type(condition.value)}") else: value = condition.value filter_func = _get_file_filter_func( diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py index 81f2df0891..3f32fa894a 100644 --- a/api/core/workflow/nodes/llm/file_saver.py +++ b/api/core/workflow/nodes/llm/file_saver.py @@ -46,7 +46,7 @@ class LLMFileSaver(tp.Protocol): dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` and `tar.gz` are not. """ - pass + raise NotImplementedError() def save_remote_url(self, url: str, file_type: FileType) -> File: """save_remote_url saves the file from a remote url returned by LLM. @@ -56,7 +56,7 @@ class LLMFileSaver(tp.Protocol): :param url: the url of the file. :param file_type: the file type of the file, check `FileType` enum for reference. """ - pass + raise NotImplementedError() EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index ad969cdad1..aff84433b2 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -92,7 +92,7 @@ def fetch_memory( return None # get conversation id - conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value]) + conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) if not isinstance(conversation_id_variable, StringSegment): return None conversation_id = conversation_id_variable.value @@ -143,7 +143,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs Provider.tenant_id == tenant_id, # TODO: Use provider name with prefix after the data migration. Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, + Provider.provider_type == ProviderType.SYSTEM, Provider.quota_type == system_configuration.current_quota_type.value, Provider.quota_limit > Provider.quota_used, ) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 7767440be6..13f6d904e6 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -23,6 +23,7 @@ from core.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, LLMStructuredOutput, LLMUsage, ) @@ -127,7 +128,7 @@ class LLMNode(Node): graph_runtime_state=graph_runtime_state, ) # LLM file outputs, used for MultiModal outputs. - self._file_outputs: list[File] = [] + self._file_outputs = [] if llm_file_saver is None: llm_file_saver = FileSaverImpl( @@ -165,6 +166,7 @@ class LLMNode(Node): node_inputs: dict[str, Any] = {} process_data: dict[str, Any] = {} result_text = "" + clean_text = "" usage = LLMUsage.empty_usage() finish_reason = None reasoning_content = None @@ -278,6 +280,13 @@ class LLMNode(Node): # Extract clean text from tags clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format) + # Process structured output if available from the event. + structured_output = ( + LLMStructuredOutput(structured_output=event.structured_output) + if event.structured_output + else None + ) + # deduct quota llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) break @@ -936,7 +945,7 @@ class LLMNode(Node): variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector if typed_node_data.memory: - variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] + variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] if typed_node_data.prompt_config: enable_jinja = False @@ -1048,7 +1057,7 @@ class LLMNode(Node): @staticmethod def handle_blocking_result( *, - invoke_result: LLMResult, + invoke_result: LLMResult | LLMResultWithStructuredOutput, saver: LLMFileSaver, file_outputs: list["File"], reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -1079,6 +1088,8 @@ class LLMNode(Node): finish_reason=None, # Reasoning content for workflow variables and downstream nodes reasoning_content=reasoning_content, + # Pass structured output if enabled + structured_output=getattr(invoke_result, "structured_output", None), ) @staticmethod diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 38aef06d24..e5bce1230c 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -18,7 +18,7 @@ class LoopEndNode(Node): _node_data: LoopEndNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopEndNodeData(**data) + self._node_data = LoopEndNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 2b988ad944..790975d556 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,3 +1,4 @@ +import contextlib import json import logging from collections.abc import Callable, Generator, Mapping, Sequence @@ -127,11 +128,13 @@ class LoopNode(Node): try: reach_break_condition = False if break_conditions: - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) + with contextlib.suppress(ValueError): + _, _, reach_break_condition = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=break_conditions, + operator=logical_operator, + ) + if reach_break_condition: loop_count = 0 cost_tokens = 0 @@ -295,42 +298,11 @@ class LoopNode(Node): variable_mapping = {} - # init graph - from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool - from core.workflow.graph import Graph - from core.workflow.nodes.node_factory import DifyNodeFactory + # Extract loop node IDs statically from graph_config - # Create minimal GraphInitParams for static analysis - graph_init_params = GraphInitParams( - tenant_id="", - app_id="", - workflow_id="", - graph_config=graph_config, - user_id="", - user_from="", - invoke_from="", - call_depth=0, - ) + loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) - # Create minimal GraphRuntimeState for static analysis - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(), - start_at=0, - ) - - # Create node factory for static analysis - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - loop_graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=typed_node_data.start_node_id, - ) - - if not loop_graph: - raise ValueError("loop graph not found") - - # Get node configs from graph_config instead of non-existent node_id_config_mapping + # Get node configs from graph_config node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} for sub_node_id, sub_node_config in node_configs.items(): if sub_node_config.get("data", {}).get("loop_id") != node_id: @@ -371,12 +343,35 @@ class LoopNode(Node): variable_mapping[f"{node_id}.{loop_variable.label}"] = selector # remove variable out from loop - variable_mapping = { - key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids - } + variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids} return variable_mapping + @classmethod + def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: + """ + Extract node IDs that belong to a specific loop from graph configuration. + + This method statically analyzes the graph configuration to find all nodes + that are part of the specified loop, without creating actual node instances. + + :param graph_config: the complete graph configuration + :param loop_node_id: the ID of the loop node + :return: set of node IDs that belong to the loop + """ + loop_node_ids = set() + + # Find all nodes that belong to this loop + nodes = graph_config.get("nodes", []) + for node in nodes: + node_data = node.get("data", {}) + if node_data.get("loop_id") == loop_node_id: + node_id = node.get("id") + if node_id: + loop_node_ids.add(node_id) + + return loop_node_ids + @staticmethod def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: """Get the appropriate segment type for a constant value.""" diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index e777a8cbe9..e065dc90a0 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -18,7 +18,7 @@ class LoopStartNode(Node): _node_data: LoopStartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopStartNodeData(**data) + self._node_data = LoopStartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index ab7ddcc32a..b74be8f206 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -179,6 +179,6 @@ CHAT_EXAMPLE = [ "required": ["food"], }, }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}}, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}}, }, ] diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 483cfff574..592a6566fd 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -68,7 +68,7 @@ class QuestionClassifierNode(Node): graph_runtime_state=graph_runtime_state, ) # LLM file outputs, used for MultiModal outputs. - self._file_outputs: list[File] = [] + self._file_outputs = [] if llm_file_saver is None: llm_file_saver = FileSaverImpl( @@ -111,9 +111,9 @@ class QuestionClassifierNode(Node): query = variable.value if variable else None variables = {"query": query} # fetch model config - model_instance, model_config = LLMNode._fetch_model_config( - node_data_model=node_data.model, + model_instance, model_config = llm_utils.fetch_model_config( tenant_id=self.tenant_id, + node_data_model=node_data.model, ) # fetch memory memory = llm_utils.fetch_memory( diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 2f33c54128..3b134be1a1 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -16,7 +16,7 @@ class StartNode(Node): _node_data: StartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = StartNodeData(**data) + self._node_data = StartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index cf05ef253a..254a8318b5 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,7 +1,7 @@ -import os from collections.abc import Mapping, Sequence from typing import Any +from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -9,7 +9,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData -MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) +MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH class TemplateTransformNode(Node): diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 5f2abcd378..cd0094f531 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -79,7 +79,7 @@ class ToolNode(Node): # But for backward compatibility with historical data # this version field judgment is still preserved here. variable_pool: VariablePool | None = None - if node_data.version != "1" or node_data.tool_node_version != "1": + if node_data.version != "1" or node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool @@ -224,7 +224,7 @@ class ToolNode(Node): return result def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + variable = variable_pool.get(["sys", SystemVariableKey.FILES]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index be00d55937..0ac0d3d858 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -15,7 +15,7 @@ class VariableAggregatorNode(Node): _node_data: VariableAssignerNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = VariableAssignerNodeData(**data) + self._node_data = VariableAssignerNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index a35215855e..1b31022495 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -66,8 +66,8 @@ def load_into_variable_pool( # NOTE(QuantumGhost): this logic needs to be in sync with # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. node_variable_list = key.split(".") - if len(node_variable_list) < 1: - raise ValueError(f"Invalid variable key: {key}. It should have at least one element.") + if len(node_variable_list) < 2: + raise ValueError(f"Invalid variable key: {key}. It should have at least two elements.") if key in user_inputs: continue node_variable_key = ".".join(node_variable_list[1:]) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 49645ff120..4cd885cfa5 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -227,7 +227,7 @@ class WorkflowEntry: "height": node_height, "type": "custom", "data": { - "type": NodeType.START.value, + "type": NodeType.START, "title": "Start", "desc": "Start", }, @@ -416,4 +416,8 @@ class WorkflowEntry: # append variable and value to variable pool if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: + # In single run, the input_value is set as the LLM's structured output value within the variable_pool. + if len(variable_key_list) == 2 and variable_key_list[0] == "structured_output": + input_value = {variable_key_list[1]: input_value} + variable_key_list = variable_key_list[0:1] variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 6c9fc0bf1d..1b44d8a1e2 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -12,9 +12,9 @@ def handle(sender, **kwargs): if synced_draft_workflow is None: return for node_data in synced_draft_workflow.graph_dict.get("nodes", []): - if node_data.get("data", {}).get("type") == NodeType.TOOL.value: + if node_data.get("data", {}).get("type") == NodeType.TOOL: try: - tool_entity = ToolEntity(**node_data["data"]) + tool_entity = ToolEntity.model_validate(node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( provider_type=tool_entity.provider_type, provider_id=tool_entity.provider_id, diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 898ec1f153..53e0065f6e 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -53,7 +53,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: # fetch all knowledge retrieval nodes knowledge_retrieval_nodes = [ - node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value + node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL ] if not knowledge_retrieval_nodes: @@ -61,7 +61,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: for node in knowledge_retrieval_nodes: try: - node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) + node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {})) dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids) except Exception: continue diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 27efa539dc..c0694d4efe 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -139,7 +139,7 @@ def handle(sender: Message, **kwargs): filters=_ProviderUpdateFilters( tenant_id=tenant_id, provider_name=ModelProviderID(model_config.provider).provider_name, - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, quota_type=provider_configuration.system_configuration.current_quota_type.value, ), values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py index 56a69a1862..4a6490b9f0 100644 --- a/api/extensions/ext_app_metrics.py +++ b/api/extensions/ext_app_metrics.py @@ -10,14 +10,14 @@ from dify_app import DifyApp def init_app(app: DifyApp): @app.after_request - def after_request(response): + def after_request(response): # pyright: ignore[reportUnusedFunction] """Add Version headers to the response.""" response.headers.add("X-Version", dify_config.project.version) response.headers.add("X-Env", dify_config.DEPLOY_ENV) return response @app.route("/health") - def health(): + def health(): # pyright: ignore[reportUnusedFunction] return Response( json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.project.version}), status=200, @@ -25,7 +25,7 @@ def init_app(app: DifyApp): ) @app.route("/threads") - def threads(): + def threads(): # pyright: ignore[reportUnusedFunction] num_threads = threading.active_count() threads = threading.enumerate() @@ -50,7 +50,7 @@ def init_app(app: DifyApp): } @app.route("/db-pool-stat") - def pool_stat(): + def pool_stat(): # pyright: ignore[reportUnusedFunction] from extensions.ext_database import db engine = db.engine diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 585539e2ce..6d7d81ed87 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -145,6 +145,7 @@ def init_app(app: DifyApp) -> Celery: } if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") + imports.append("tasks.process_tenant_plugin_autoupgrade_check_task") beat_schedule["check_upgradable_plugin_task"] = { "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task", "schedule": crontab(minute="*/15"), diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index 067ce39e4f..c90b1d0a9f 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -10,7 +10,7 @@ from models.engine import db logger = logging.getLogger(__name__) # Global flag to avoid duplicate registration of event listener -_GEVENT_COMPATIBILITY_SETUP: bool = False +_gevent_compatibility_setup: bool = False def _safe_rollback(connection): @@ -26,14 +26,14 @@ def _safe_rollback(connection): def _setup_gevent_compatibility(): - global _GEVENT_COMPATIBILITY_SETUP # pylint: disable=global-statement + global _gevent_compatibility_setup # pylint: disable=global-statement # Avoid duplicate registration - if _GEVENT_COMPATIBILITY_SETUP: + if _gevent_compatibility_setup: return @event.listens_for(Pool, "reset") - def _safe_reset(dbapi_connection, connection_record, reset_state): # pylint: disable=unused-argument + def _safe_reset(dbapi_connection, connection_record, reset_state): # pyright: ignore[reportUnusedFunction] if reset_state.terminate_only: return @@ -47,7 +47,7 @@ def _setup_gevent_compatibility(): except (AttributeError, ImportError): _safe_rollback(dbapi_connection) - _GEVENT_COMPATIBILITY_SETUP = True + _gevent_compatibility_setup = True def init_app(app: DifyApp): diff --git a/api/extensions/ext_import_modules.py b/api/extensions/ext_import_modules.py index 9566f430b6..4eb363ff93 100644 --- a/api/extensions/ext_import_modules.py +++ b/api/extensions/ext_import_modules.py @@ -2,4 +2,4 @@ from dify_app import DifyApp def init_app(app: DifyApp): - from events import event_handlers # noqa: F401 + from events import event_handlers # noqa: F401 # pyright: ignore[reportUnusedImport] diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index b0059693e2..cb6e4849a9 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -136,8 +136,8 @@ def init_app(app: DifyApp): from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter from opentelemetry.instrumentation.celery import CeleryInstrumentor from opentelemetry.instrumentation.flask import FlaskInstrumentor + from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor from opentelemetry.instrumentation.redis import RedisInstrumentor - from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider from opentelemetry.propagate import set_global_textmap @@ -237,7 +237,7 @@ def init_app(app: DifyApp): instrument_exception_logging() init_sqlalchemy_instrumentor(app) RedisInstrumentor().instrument() - RequestsInstrumentor().instrument() + HTTPXClientInstrumentor().instrument() atexit.register(shutdown_tracer) diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 6cfa99a62a..5ed7840211 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -4,7 +4,6 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: - import openai import sentry_sdk from langfuse import parse_error # type: ignore from sentry_sdk.integrations.celery import CeleryIntegration @@ -28,7 +27,6 @@ def init_app(app: DifyApp): HTTPException, ValueError, FileNotFoundError, - openai.APIStatusError, InvokeRateLimitError, parse_error.defaultErrorResponse, ], diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 00bf5d4f93..5da4737138 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -33,7 +33,9 @@ class AliyunOssStorage(BaseStorage): def load_once(self, filename: str) -> bytes: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - data: bytes = obj.read() + data = obj.read() + if not isinstance(data, bytes): + return b"" return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index e755ab089a..6ab2a95e3c 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -39,10 +39,10 @@ class AwsS3Storage(BaseStorage): self.client.head_bucket(Bucket=self.bucket_name) except ClientError as e: # if bucket not exists, create it - if e.response["Error"]["Code"] == "404": + if e.response.get("Error", {}).get("Code") == "404": self.client.create_bucket(Bucket=self.bucket_name) # if bucket is not accessible, pass, maybe the bucket is existing but not accessible - elif e.response["Error"]["Code"] == "403": + elif e.response.get("Error", {}).get("Code") == "403": pass else: # other error, raise exception @@ -55,7 +55,7 @@ class AwsS3Storage(BaseStorage): try: data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": + if ex.response.get("Error", {}).get("Code") == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -66,7 +66,7 @@ class AwsS3Storage(BaseStorage): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": + if ex.response.get("Error", {}).get("Code") == "NoSuchKey": raise FileNotFoundError("file not found") elif "reached max retries" in str(ex): raise ValueError("please do not request the same file too frequently") diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 9053aece89..4bccaf13c8 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -27,24 +27,38 @@ class AzureBlobStorage(BaseStorage): self.credential = None def save(self, filename, data): + if not self.bucket_name: + return + client = self._sync_client() blob_container = client.get_container_client(container=self.bucket_name) blob_container.upload_blob(filename, data) def load_once(self, filename: str) -> bytes: + if not self.bucket_name: + raise FileNotFoundError("Azure bucket name is not configured.") + client = self._sync_client() blob = client.get_container_client(container=self.bucket_name) blob = blob.get_blob_client(blob=filename) - data: bytes = blob.download_blob().readall() + data = blob.download_blob().readall() + if not isinstance(data, bytes): + raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}") return data def load_stream(self, filename: str) -> Generator: + if not self.bucket_name: + raise FileNotFoundError("Azure bucket name is not configured.") + client = self._sync_client() blob = client.get_blob_client(container=self.bucket_name, blob=filename) blob_data = blob.download_blob() yield from blob_data.chunks() def download(self, filename, target_filepath): + if not self.bucket_name: + return + client = self._sync_client() blob = client.get_blob_client(container=self.bucket_name, blob=filename) @@ -53,12 +67,18 @@ class AzureBlobStorage(BaseStorage): blob_data.readinto(my_blob) def exists(self, filename): + if not self.bucket_name: + return False + client = self._sync_client() blob = client.get_blob_client(container=self.bucket_name, blob=filename) return blob.exists() def delete(self, filename): + if not self.bucket_name: + return + client = self._sync_client() blob_container = client.get_container_client(container=self.bucket_name) diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 2ffac9a92d..06c528ca41 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -430,7 +430,7 @@ class ClickZettaVolumeStorage(BaseStorage): rows = self._execute_sql(sql, fetch=True) - exists = len(rows) > 0 + exists = len(rows) > 0 if rows else False logger.debug("File %s exists check: %s", filename, exists) return exists except Exception as e: @@ -509,16 +509,17 @@ class ClickZettaVolumeStorage(BaseStorage): rows = self._execute_sql(sql, fetch=True) result = [] - for row in rows: - file_path = row[0] # relative_path column + if rows: + for row in rows: + file_path = row[0] # relative_path column - # For User Volume, remove dify prefix from results - dify_prefix_with_slash = f"{self._config.dify_prefix}/" - if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash): - file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix + # For User Volume, remove dify prefix from results + dify_prefix_with_slash = f"{self._config.dify_prefix}/" + if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash): + file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix - if files and not file_path.endswith("/") or directories and file_path.endswith("/"): - result.append(file_path) + if files and not file_path.endswith("/") or directories and file_path.endswith("/"): + result.append(file_path) logger.debug("Scanned %d items in path %s", len(result), path) return result diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 6ab02ad8cc..dc5aa8e39c 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -264,7 +264,7 @@ class FileLifecycleManager: logger.warning("File %s not found in metadata", filename) return False - metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value + metadata_dict[filename]["status"] = FileStatus.ARCHIVED metadata_dict[filename]["modified_at"] = datetime.now().isoformat() self._save_metadata(metadata_dict) @@ -309,7 +309,7 @@ class FileLifecycleManager: # Update metadata metadata_dict = self._load_metadata() if filename in metadata_dict: - metadata_dict[filename]["status"] = FileStatus.DELETED.value + metadata_dict[filename]["status"] = FileStatus.DELETED metadata_dict[filename]["modified_at"] = datetime.now().isoformat() self._save_metadata(metadata_dict) diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index eb1116638f..6dcf800abb 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -439,6 +439,11 @@ class VolumePermissionManager: self._permission_cache.clear() logger.debug("Permission cache cleared") + @property + def volume_type(self) -> str | None: + """Get the volume type.""" + return self._volume_type + def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]: """Get permission summary @@ -632,13 +637,13 @@ def check_volume_permission(permission_manager: VolumePermissionManager, operati VolumePermissionError: If no permission """ if not permission_manager.validate_operation(operation, dataset_id): - error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume" + error_message = f"Permission denied for operation '{operation}' on {permission_manager.volume_type} volume" if dataset_id: error_message += f" (dataset: {dataset_id})" raise VolumePermissionError( error_message, operation=operation, - volume_type=permission_manager._volume_type or "unknown", + volume_type=permission_manager.volume_type or "unknown", dataset_id=dataset_id, ) diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 705639f42e..7f59252f2f 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -35,12 +35,16 @@ class GoogleCloudStorage(BaseStorage): def load_once(self, filename: str) -> bytes: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) + if blob is None: + raise FileNotFoundError("File not found") data: bytes = blob.download_as_bytes() return data def load_stream(self, filename: str) -> Generator: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) + if blob is None: + raise FileNotFoundError("File not found") with blob.open(mode="rb") as blob_stream: while chunk := blob_stream.read(4096): yield chunk @@ -48,6 +52,8 @@ class GoogleCloudStorage(BaseStorage): def download(self, filename, target_filepath): bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) + if blob is None: + raise FileNotFoundError("File not found") blob.download_to_filename(target_filepath) def exists(self, filename): diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 07f1d19970..3e75ecb7a9 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -45,7 +45,7 @@ class HuaweiObsStorage(BaseStorage): def _get_meta(self, filename): res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename) - if res.status < 300: + if res and res.status and res.status < 300: return res else: return None diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index b10391c7f1..f7146adba6 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -3,9 +3,9 @@ import os from collections.abc import Generator from pathlib import Path +import opendal from dotenv import dotenv_values from opendal import Operator -from opendal.layers import RetryLayer from extensions.storage.base_storage import BaseStorage @@ -35,7 +35,7 @@ class OpenDALStorage(BaseStorage): root = kwargs.get("root", "storage") Path(root).mkdir(parents=True, exist_ok=True) - retry_layer = RetryLayer(max_times=3, factor=2.0, jitter=True) + retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True) self.op = Operator(scheme=scheme, **kwargs).layer(retry_layer) logger.debug("opendal operator created with scheme %s", scheme) logger.debug("added retry layer to opendal operator") diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index 82829f7fd5..acc00cbd6b 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -29,7 +29,7 @@ class OracleOCIStorage(BaseStorage): try: data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": + if ex.response.get("Error", {}).get("Code") == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -40,7 +40,7 @@ class OracleOCIStorage(BaseStorage): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": + if ex.response.get("Error", {}).get("Code") == "NoSuchKey": raise FileNotFoundError("File not found") else: raise diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 711c3f7211..2ca84d4c15 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -46,13 +46,13 @@ class SupabaseStorage(BaseStorage): Path(target_filepath).write_bytes(result) def exists(self, filename): - result = self.client.storage.from_(self.bucket_name).list(filename) - if result.count() > 0: + result = self.client.storage.from_(self.bucket_name).list(path=filename) + if len(result) > 0: return True return False def delete(self, filename): - self.client.storage.from_(self.bucket_name).remove(filename) + self.client.storage.from_(self.bucket_name).remove([filename]) def bucket_exists(self): buckets = self.client.storage.list_buckets() diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 32839d3497..8ed8e4c170 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -11,6 +11,14 @@ class VolcengineTosStorage(BaseStorage): def __init__(self): super().__init__() + if not dify_config.VOLCENGINE_TOS_ACCESS_KEY: + raise ValueError("VOLCENGINE_TOS_ACCESS_KEY is not set") + if not dify_config.VOLCENGINE_TOS_SECRET_KEY: + raise ValueError("VOLCENGINE_TOS_SECRET_KEY is not set") + if not dify_config.VOLCENGINE_TOS_ENDPOINT: + raise ValueError("VOLCENGINE_TOS_ENDPOINT is not set") + if not dify_config.VOLCENGINE_TOS_REGION: + raise ValueError("VOLCENGINE_TOS_REGION is not set") self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME self.client = tos.TosClientV2( ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY, @@ -20,27 +28,39 @@ class VolcengineTosStorage(BaseStorage): ) def save(self, filename, data): + if not self.bucket_name: + raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set") self.client.put_object(bucket=self.bucket_name, key=filename, content=data) def load_once(self, filename: str) -> bytes: + if not self.bucket_name: + raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set") data = self.client.get_object(bucket=self.bucket_name, key=filename).read() if not isinstance(data, bytes): raise TypeError(f"Expected bytes, got {type(data).__name__}") return data def load_stream(self, filename: str) -> Generator: + if not self.bucket_name: + raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set") response = self.client.get_object(bucket=self.bucket_name, key=filename) while chunk := response.read(4096): yield chunk def download(self, filename, target_filepath): + if not self.bucket_name: + raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set") self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath) def exists(self, filename): + if not self.bucket_name: + return False res = self.client.head_object(bucket=self.bucket_name, key=filename) if res.status_code != 200: return False return True def delete(self, filename): + if not self.bucket_name: + return self.client.delete_object(bucket=self.bucket_name, key=filename) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 588168bd39..69fd1a6da3 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -8,6 +8,7 @@ from typing import Any import httpx from sqlalchemy import select from sqlalchemy.orm import Session +from werkzeug.http import parse_options_header from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers @@ -44,7 +45,7 @@ def build_from_message_file( } # Set the correct ID field based on transfer method - if message_file.transfer_method == FileTransferMethod.TOOL_FILE.value: + if message_file.transfer_method == FileTransferMethod.TOOL_FILE: mapping["tool_file_id"] = message_file.upload_file_id else: mapping["upload_file_id"] = message_file.upload_file_id @@ -247,6 +248,25 @@ def _build_from_remote_url( ) +def _extract_filename(url_path: str, content_disposition: str | None) -> str | None: + filename = None + # Try to extract from Content-Disposition header first + if content_disposition: + _, params = parse_options_header(content_disposition) + # RFC 5987 https://datatracker.ietf.org/doc/html/rfc5987: filename* takes precedence over filename + filename = params.get("filename*") or params.get("filename") + # Fallback to URL path if no filename from header + if not filename: + filename = os.path.basename(url_path) + return filename or None + + +def _guess_mime_type(filename: str) -> str: + """Guess MIME type from filename, returning empty string if None.""" + guessed_mime, _ = mimetypes.guess_type(filename) + return guessed_mime or "" + + def _get_remote_file_info(url: str): file_size = -1 parsed_url = urllib.parse.urlparse(url) @@ -254,23 +274,26 @@ def _get_remote_file_info(url: str): filename = os.path.basename(url_path) # Initialize mime_type from filename as fallback - mime_type, _ = mimetypes.guess_type(filename) - if mime_type is None: - mime_type = "" + mime_type = _guess_mime_type(filename) resp = ssrf_proxy.head(url, follow_redirects=True) if resp.status_code == httpx.codes.OK: - if content_disposition := resp.headers.get("Content-Disposition"): - filename = str(content_disposition.split("filename=")[-1].strip('"')) - # Re-guess mime_type from updated filename - mime_type, _ = mimetypes.guess_type(filename) - if mime_type is None: - mime_type = "" + content_disposition = resp.headers.get("Content-Disposition") + extracted_filename = _extract_filename(url_path, content_disposition) + if extracted_filename: + filename = extracted_filename + mime_type = _guess_mime_type(filename) file_size = int(resp.headers.get("Content-Length", file_size)) # Fallback to Content-Type header if mime_type is still empty if not mime_type: mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + if not filename: + extension = mimetypes.guess_extension(mime_type) or ".bin" + filename = f"{uuid.uuid4().hex}{extension}" + if not mime_type: + mime_type = _guess_mime_type(filename) + return mime_type, filename, file_size @@ -345,9 +368,7 @@ def _build_from_datasource_file( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type - ) + file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type return File( id=mapping.get("datasource_file_id"), diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 2104e66254..494194369a 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -142,6 +142,8 @@ def build_segment(value: Any, /) -> Segment: # below if value is None: return NoneSegment() + if isinstance(value, Segment): + return value if isinstance(value, str): return StringSegment(value=value) if isinstance(value, bool): diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index dd359e2f5f..c12ebc09c8 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -33,6 +33,7 @@ file_fields = { "created_by": fields.String, "created_at": TimestampField, "preview_url": fields.String, + "source_url": fields.String, } diff --git a/api/gunicorn.conf.py b/api/gunicorn.conf.py index fc91a43670..943ee100ca 100644 --- a/api/gunicorn.conf.py +++ b/api/gunicorn.conf.py @@ -1,10 +1,32 @@ import psycogreen.gevent as pscycogreen_gevent # type: ignore +from gevent import events as gevent_events from grpc.experimental import gevent as grpc_gevent # type: ignore +# NOTE(QuantumGhost): here we cannot use post_fork to patch gRPC, as +# grpc_gevent.init_gevent must be called after patching stdlib. +# Gunicorn calls `post_init` before applying monkey patch. +# Use `post_init` to setup gRPC gevent support would cause deadlock and +# some other weird issues. +# +# ref: +# - https://github.com/grpc/grpc/blob/62533ea13879d6ee95c6fda11ec0826ca822c9dd/src/python/grpcio/grpc/experimental/gevent.py +# - https://github.com/gevent/gevent/issues/2060#issuecomment-3016768668 +# - https://github.com/benoitc/gunicorn/blob/master/gunicorn/arbiter.py#L607-L613 -def post_fork(server, worker): + +def post_patch(event): + # this function is only called for gevent worker. + # from gevent docs (https://www.gevent.org/api/gevent.monkey.html): + # You can also subscribe to the events to provide additional patching beyond what gevent distributes, either for + # additional standard library modules, or for third-party packages. The suggested time to do this patching is in + # the subscriber for gevent.events.GeventDidPatchBuiltinModulesEvent. + if not isinstance(event, gevent_events.GeventDidPatchBuiltinModulesEvent): + return # grpc gevent grpc_gevent.init_gevent() - server.log.info("gRPC patched with gevent.") + print("gRPC patched with gevent.", flush=True) # noqa: T201 pscycogreen_gevent.patch_psycopg() - server.log.info("psycopg2 patched with gevent.") + print("psycopg2 patched with gevent.", flush=True) # noqa: T201 + + +gevent_events.subscribers.append(post_patch) diff --git a/api/libs/collection_utils.py b/api/libs/collection_utils.py new file mode 100644 index 0000000000..f97308ca44 --- /dev/null +++ b/api/libs/collection_utils.py @@ -0,0 +1,14 @@ +def convert_to_lower_and_upper_set(inputs: list[str] | set[str]) -> set[str]: + """ + Convert a list or set of strings to a set containing both lower and upper case versions of each string. + + Args: + inputs (list[str] | set[str]): A list or set of strings to be converted. + + Returns: + set[str]: A set containing both lower and upper case versions of each string. + """ + if not inputs: + return set() + else: + return {case for s in inputs if s for case in (s.lower(), s.upper())} diff --git a/api/libs/external_api.py b/api/libs/external_api.py index cf91b0117f..25a82f8a96 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -94,7 +94,7 @@ def register_external_error_handlers(api: Api): got_request_exception.send(current_app, exception=e) status_code = 500 - data = getattr(e, "data", {"message": http_status_message(status_code)}) + data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) if not isinstance(data, dict): diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 9759156c0f..fc38d51005 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -27,7 +27,7 @@ import gmpy2 # type: ignore from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes -from Crypto.Util.py3compat import _copy_bytes, bord +from Crypto.Util.py3compat import bord from Crypto.Util.strxor import strxor @@ -72,7 +72,7 @@ class PKCS1OAepCipher: else: self._mgf = lambda x, y: MGF1(x, y, self._hashObj) - self._label = _copy_bytes(None, None, label) + self._label = bytes(label) self._randfunc = randfunc def can_encrypt(self): @@ -120,7 +120,7 @@ class PKCS1OAepCipher: # Step 2b ps = b"\x00" * ps_len # Step 2c - db = lHash + ps + b"\x01" + _copy_bytes(None, None, message) + db = lHash + ps + b"\x01" + bytes(message) # Step 2d ros = self._randfunc(hLen) # Step 2e diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 35bd6c2c7c..889a5a3248 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,7 +1,7 @@ import urllib.parse from dataclasses import dataclass -import requests +import httpx @dataclass @@ -58,7 +58,7 @@ class GitHubOAuth(OAuth): "redirect_uri": self.redirect_uri, } headers = {"Accept": "application/json"} - response = requests.post(self._TOKEN_URL, data=data, headers=headers) + response = httpx.post(self._TOKEN_URL, data=data, headers=headers) response_json = response.json() access_token = response_json.get("access_token") @@ -70,11 +70,11 @@ class GitHubOAuth(OAuth): def get_raw_user_info(self, token: str): headers = {"Authorization": f"token {token}"} - response = requests.get(self._USER_INFO_URL, headers=headers) + response = httpx.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() user_info = response.json() - email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) + email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) @@ -112,7 +112,7 @@ class GoogleOAuth(OAuth): "redirect_uri": self.redirect_uri, } headers = {"Accept": "application/json"} - response = requests.post(self._TOKEN_URL, data=data, headers=headers) + response = httpx.post(self._TOKEN_URL, data=data, headers=headers) response_json = response.json() access_token = response_json.get("access_token") @@ -124,7 +124,7 @@ class GoogleOAuth(OAuth): def get_raw_user_info(self, token: str): headers = {"Authorization": f"Bearer {token}"} - response = requests.get(self._USER_INFO_URL, headers=headers) + response = httpx.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() return response.json() diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 987c5d7135..ae0ae3bcb6 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,7 +1,7 @@ import urllib.parse from typing import Any -import requests +import httpx from flask_login import current_user from sqlalchemy import select @@ -43,7 +43,7 @@ class NotionOAuth(OAuthDataSource): data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) - response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) + response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) response_json = response.json() access_token = response_json.get("access_token") @@ -239,7 +239,7 @@ class NotionOAuth(OAuthDataSource): "Notion-Version": "2022-06-28", } - response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() results.extend(response_json.get("results", [])) @@ -254,7 +254,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) + response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response_json = response.json() if response.status_code != 200: message = response_json.get("message", "unknown error") @@ -270,7 +270,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = requests.get(url=self._NOTION_BOT_USER, headers=headers) + response = httpx.get(url=self._NOTION_BOT_USER, headers=headers) response_json = response.json() if "object" in response_json and response_json["object"] == "user": user_type = response_json["type"] @@ -294,7 +294,7 @@ class NotionOAuth(OAuthDataSource): "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", } - response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() results.extend(response_json.get("results", [])) diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index ecc4b3fb98..a270fa70fa 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -14,7 +14,7 @@ class SendGridClient: def send(self, mail: dict): logger.debug("Sending email with SendGrid") - + _to = "" try: _to = mail["to"] @@ -28,7 +28,7 @@ class SendGridClient: content = Content("text/html", mail["html"]) sg_mail = Mail(from_email, to_email, subject, content) mail_json = sg_mail.get() - response = sg.client.mail.send.post(request_body=mail_json) # ty: ignore [call-non-callable] + response = sg.client.mail.send.post(request_body=mail_json) # type: ignore logger.debug(response.status_code) logger.debug(response.body) logger.debug(response.headers) diff --git a/api/libs/validators.py b/api/libs/validators.py new file mode 100644 index 0000000000..4d762e8116 --- /dev/null +++ b/api/libs/validators.py @@ -0,0 +1,5 @@ +def validate_description_length(description: str | None) -> str | None: + """Validate description length.""" + if description and len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description diff --git a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py index 742cfc345a..53a95141ec 100644 --- a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py +++ b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py @@ -47,7 +47,7 @@ def upgrade(): sa.Column('plugin_id', sa.String(length=255), nullable=False), sa.Column('auth_type', sa.String(length=255), nullable=False), sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False), - sa.Column('avatar_url', sa.String(length=255), nullable=True), + sa.Column('avatar_url', sa.Text(), nullable=True), sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False), sa.Column('expires_at', sa.Integer(), server_default='-1', nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), diff --git a/api/models/account.py b/api/models/account.py index 8c1f990aa2..86cd9e41b5 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,15 +1,16 @@ import enum import json +from dataclasses import field from datetime import datetime from typing import Any, Optional import sqlalchemy as sa from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor +from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated -from models.base import Base +from models.base import TypeBase from .engine import db from .types import StringUUID @@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum): CLOSED = "closed" -class Account(UserMixin, Base): +class Account(UserMixin, TypeBase): __tablename__ = "accounts" __table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email")) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) name: Mapped[str] = mapped_column(String(255)) email: Mapped[str] = mapped_column(String(255)) - password: Mapped[str | None] = mapped_column(String(255)) - password_salt: Mapped[str | None] = mapped_column(String(255)) - avatar: Mapped[str | None] = mapped_column(String(255), nullable=True) - interface_language: Mapped[str | None] = mapped_column(String(255)) - interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True) - timezone: Mapped[str | None] = mapped_column(String(255)) - last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True) - last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying")) - initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + password: Mapped[str | None] = mapped_column(String(255), default=None) + password_salt: Mapped[str | None] = mapped_column(String(255), default=None) + avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + interface_language: Mapped[str | None] = mapped_column(String(255), default=None) + interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + timezone: Mapped[str | None] = mapped_column(String(255), default=None) + last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + last_active_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + status: Mapped[str] = mapped_column( + String(16), server_default=sa.text("'active'::character varying"), default="active" + ) + initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) - @reconstructor - def init_on_load(self): - self.role: TenantAccountRole | None = None - self._current_tenant: Tenant | None = None + role: TenantAccountRole | None = field(default=None, init=False) + _current_tenant: "Tenant | None" = field(default=None, init=False) @property def is_password_set(self): @@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum): ARCHIVE = "archive" -class Tenant(Base): +class Tenant(TypeBase): __tablename__ = "tenants" __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) name: Mapped[str] = mapped_column(String(255)) - encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text) - plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) - custom_config: Mapped[str | None] = mapped_column(sa.Text) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None) + plan: Mapped[str] = mapped_column( + String(255), server_default=sa.text("'basic'::character varying"), default="basic" + ) + status: Mapped[str] = mapped_column( + String(255), server_default=sa.text("'normal'::character varying"), default="normal" + ) + custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False) def get_accounts(self) -> list[Account]: return list( @@ -257,7 +270,7 @@ class Tenant(Base): self.custom_config = json.dumps(value) -class TenantAccountJoin(Base): +class TenantAccountJoin(TypeBase): __tablename__ = "tenant_account_joins" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -266,17 +279,21 @@ class TenantAccountJoin(Base): sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) - current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) - role: Mapped[str] = mapped_column(String(16), server_default="normal") - invited_by: Mapped[str | None] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False) + role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal") + invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) -class AccountIntegrate(Base): +class AccountIntegrate(TypeBase): __tablename__ = "account_integrates" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -284,16 +301,20 @@ class AccountIntegrate(Base): sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) account_id: Mapped[str] = mapped_column(StringUUID) provider: Mapped[str] = mapped_column(String(16)) open_id: Mapped[str] = mapped_column(String(255)) encrypted_token: Mapped[str] = mapped_column(String(255)) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) -class InvitationCode(Base): +class InvitationCode(TypeBase): __tablename__ = "invitation_codes" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"), @@ -301,18 +322,22 @@ class InvitationCode(Base): sa.Index("invitation_codes_code_idx", "code", "status"), ) - id: Mapped[int] = mapped_column(sa.Integer) + id: Mapped[int] = mapped_column(sa.Integer, init=False) batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) - used_at: Mapped[datetime | None] = mapped_column(DateTime) - used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID) - used_by_account_id: Mapped[str | None] = mapped_column(StringUUID) - deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) + status: Mapped[str] = mapped_column( + String(16), server_default=sa.text("'unused'::character varying"), default="unused" + ) + used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None) + used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None) + used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None) + deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False + ) -class TenantPluginPermission(Base): +class TenantPluginPermission(TypeBase): class InstallPermission(enum.StrEnum): EVERYONE = "everyone" ADMINS = "admins" @@ -329,13 +354,17 @@ class TenantPluginPermission(Base): sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") - debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") + install_permission: Mapped[InstallPermission] = mapped_column( + String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE + ) + debug_permission: Mapped[DebugPermission] = mapped_column( + String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY + ) -class TenantPluginAutoUpgradeStrategy(Base): +class TenantPluginAutoUpgradeStrategy(TypeBase): class StrategySetting(enum.StrEnum): DISABLED = "disabled" FIX_ONLY = "fix_only" @@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base): sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") - upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day - upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") - exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) - include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + strategy_setting: Mapped[StrategySetting] = mapped_column( + String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY + ) + upgrade_mode: Mapped[UpgradeMode] = mapped_column( + String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE + ) + exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list) + include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list) + upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 60167d9069..e86826fc3d 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -9,7 +9,7 @@ from .base import Base from .types import StringUUID -class APIBasedExtensionPoint(enum.Enum): +class APIBasedExtensionPoint(enum.StrEnum): APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query" PING = "ping" APP_MODERATION_INPUT = "app.moderation.input" diff --git a/api/models/dataset.py b/api/models/dataset.py index 2c4059f800..5653445f2b 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -61,18 +61,18 @@ class Dataset(Base): created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = mapped_column(db.String(255), nullable=True) embedding_model_provider = mapped_column(db.String(255), nullable=True) - keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10")) + keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10")) collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(JSONB, nullable=True) - built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - icon_info = db.Column(JSONB, nullable=True) - runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) - pipeline_id = db.Column(StringUUID, nullable=True) - chunk_structure = db.Column(db.String(255), nullable=True) - enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + icon_info = mapped_column(JSONB, nullable=True) + runtime_mode = mapped_column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) + pipeline_id = mapped_column(StringUUID, nullable=True) + chunk_structure = mapped_column(db.String(255), nullable=True) + enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true")) @property def total_documents(self): @@ -184,7 +184,7 @@ class Dataset(Base): @property def retrieval_model_dict(self): default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 2, @@ -754,7 +754,7 @@ class DocumentSegment(Base): if process_rule and process_rule.mode == "hierarchical": rules_dict = process_rule.rules_dict if rules_dict: - rules = Rule(**rules_dict) + rules = Rule.model_validate(rules_dict) if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: child_chunks = ( db.session.query(ChildChunk) @@ -772,7 +772,7 @@ class DocumentSegment(Base): if process_rule and process_rule.mode == "hierarchical": rules_dict = process_rule.rules_dict if rules_dict: - rules = Rule(**rules_dict) + rules = Rule.model_validate(rules_dict) if rules.parent_mode: child_chunks = ( db.session.query(ChildChunk) @@ -910,7 +910,7 @@ class AppDatasetJoin(Base): id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) @property def app(self): @@ -931,7 +931,7 @@ class DatasetQuery(Base): source_app_id = mapped_column(StringUUID, nullable=True) created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) class DatasetKeywordTable(Base): @@ -1226,21 +1226,21 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_built_in_templates" __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False) - chunk_structure = db.Column(db.String(255), nullable=False) - icon = db.Column(db.JSON, nullable=False) - yaml_content = db.Column(db.Text, nullable=False) - copyright = db.Column(db.String(255), nullable=False) - privacy_policy = db.Column(db.String(255), nullable=False) - position = db.Column(db.Integer, nullable=False) - install_count = db.Column(db.Integer, nullable=False, default=0) - language = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by = db.Column(StringUUID, nullable=False) - updated_by = db.Column(StringUUID, nullable=True) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(sa.Text, nullable=False) + chunk_structure = mapped_column(db.String(255), nullable=False) + icon = mapped_column(sa.JSON, nullable=False) + yaml_content = mapped_column(sa.Text, nullable=False) + copyright = mapped_column(db.String(255), nullable=False) + privacy_policy = mapped_column(db.String(255), nullable=False) + position = mapped_column(sa.Integer, nullable=False) + install_count = mapped_column(sa.Integer, nullable=False, default=0) + language = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by = mapped_column(StringUUID, nullable=False) + updated_by = mapped_column(StringUUID, nullable=True) @property def created_user_name(self): @@ -1257,20 +1257,20 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] db.Index("pipeline_customized_template_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - tenant_id = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False) - chunk_structure = db.Column(db.String(255), nullable=False) - icon = db.Column(db.JSON, nullable=False) - position = db.Column(db.Integer, nullable=False) - yaml_content = db.Column(db.Text, nullable=False) - install_count = db.Column(db.Integer, nullable=False, default=0) - language = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - updated_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id = mapped_column(StringUUID, nullable=False) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(sa.Text, nullable=False) + chunk_structure = mapped_column(db.String(255), nullable=False) + icon = mapped_column(sa.JSON, nullable=False) + position = mapped_column(sa.Integer, nullable=False) + yaml_content = mapped_column(sa.Text, nullable=False) + install_count = mapped_column(sa.Integer, nullable=False, default=0) + language = mapped_column(db.String(255), nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + updated_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def created_user_name(self): @@ -1284,17 +1284,17 @@ class Pipeline(Base): # type: ignore[name-defined] __tablename__ = "pipelines" __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) - workflow_id = db.Column(StringUUID, nullable=True) - is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(sa.Text, nullable=False, server_default=db.text("''::character varying")) + workflow_id = mapped_column(StringUUID, nullable=True) + is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) def retrieve_dataset(self, session: Session): return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() @@ -1307,25 +1307,25 @@ class DocumentPipelineExecutionLog(Base): db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - pipeline_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) - datasource_type = db.Column(db.String(255), nullable=False) - datasource_info = db.Column(db.Text, nullable=False) - datasource_node_id = db.Column(db.String(255), nullable=False) - input_data = db.Column(db.JSON, nullable=False) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + pipeline_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) + datasource_type = mapped_column(db.String(255), nullable=False) + datasource_info = mapped_column(sa.Text, nullable=False) + datasource_node_id = mapped_column(db.String(255), nullable=False) + input_data = mapped_column(sa.JSON, nullable=False) + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class PipelineRecommendedPlugin(Base): __tablename__ = "pipeline_recommended_plugins" __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - plugin_id = db.Column(db.Text, nullable=False) - provider_name = db.Column(db.Text, nullable=False) - position = db.Column(db.Integer, nullable=False, default=0) - active = db.Column(db.Boolean, nullable=False, default=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + plugin_id = mapped_column(sa.Text, nullable=False) + provider_name = mapped_column(sa.Text, nullable=False) + position = mapped_column(sa.Integer, nullable=False, default=0) + active = mapped_column(sa.Boolean, nullable=False, default=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index 9bcb81b41b..18958c8253 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -186,13 +186,13 @@ class App(Base): if len(keys) >= 4: provider_type = tool.get("provider_type", "") provider_id = tool.get("provider_id", "") - if provider_type == ToolProviderType.API.value: + if provider_type == ToolProviderType.API: try: uuid.UUID(provider_id) except Exception: continue api_provider_ids.append(provider_id) - if provider_type == ToolProviderType.BUILT_IN.value: + if provider_type == ToolProviderType.BUILT_IN: try: # check if it's hardcoded try: @@ -251,23 +251,23 @@ class App(Base): provider_type = tool.get("provider_type", "") provider_id = tool.get("provider_id", "") - if provider_type == ToolProviderType.API.value: + if provider_type == ToolProviderType.API: if uuid.UUID(provider_id) not in existing_api_providers: deleted_tools.append( { - "type": ToolProviderType.API.value, + "type": ToolProviderType.API, "tool_name": tool["tool_name"], "provider_id": provider_id, } ) - if provider_type == ToolProviderType.BUILT_IN.value: + if provider_type == ToolProviderType.BUILT_IN: generic_provider_id = GenericProviderID(provider_id) if not existing_builtin_providers[generic_provider_id.provider_name]: deleted_tools.append( { - "type": ToolProviderType.BUILT_IN.value, + "type": ToolProviderType.BUILT_IN, "tool_name": tool["tool_name"], "provider_id": provider_id, # use the original one } @@ -1154,7 +1154,7 @@ class Message(Base): files: list[File] = [] for message_file in message_files: - if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: + if message_file.transfer_method == FileTransferMethod.LOCAL_FILE: if message_file.upload_file_id is None: raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") file = file_factory.build_from_mapping( @@ -1166,7 +1166,7 @@ class Message(Base): }, tenant_id=current_app.tenant_id, ) - elif message_file.transfer_method == FileTransferMethod.REMOTE_URL.value: + elif message_file.transfer_method == FileTransferMethod.REMOTE_URL: if message_file.url is None: raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") file = file_factory.build_from_mapping( @@ -1179,7 +1179,7 @@ class Message(Base): }, tenant_id=current_app.tenant_id, ) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE.value: + elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: if message_file.upload_file_id is None: assert message_file.url is not None message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] @@ -1731,7 +1731,7 @@ class MessageChain(Base): type: Mapped[str] = mapped_column(String(255), nullable=False) input = mapped_column(sa.Text, nullable=True) output = mapped_column(sa.Text, nullable=True) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) class MessageAgentThought(Base): @@ -1769,7 +1769,7 @@ class MessageAgentThought(Base): latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) @property def files(self) -> list[Any]: @@ -1872,7 +1872,7 @@ class DatasetRetrieverResource(Base): index_node_hash = mapped_column(sa.Text, nullable=True) retriever_from = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) class Tag(Base): diff --git a/api/models/oauth.py b/api/models/oauth.py index b6a76793fc..ef23780dc8 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,7 +1,8 @@ from datetime import datetime +import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Mapped +from sqlalchemy.orm import Mapped, mapped_column from .base import Base from .engine import db @@ -15,10 +16,10 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) - provider: Mapped[str] = db.Column(db.String(255), nullable=False) - system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + system_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False) class DatasourceProvider(Base): @@ -28,19 +29,19 @@ class DatasourceProvider(Base): db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"), db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - tenant_id = db.Column(StringUUID, nullable=False) - name: Mapped[str] = db.Column(db.String(255), nullable=False) - provider: Mapped[str] = db.Column(db.String(255), nullable=False) - plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) - auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) - encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) - avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default") - is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1") + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id = mapped_column(StringUUID, nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False) + auth_type: Mapped[str] = mapped_column(db.String(255), nullable=False) + encrypted_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False) + avatar_url: Mapped[str] = mapped_column(sa.Text, nullable=True, default="default") + is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1") - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) - updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now) class DatasourceOauthTenantParamConfig(Base): @@ -50,12 +51,12 @@ class DatasourceOauthTenantParamConfig(Base): db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider: Mapped[str] = db.Column(db.String(255), nullable=False) - plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) - client_params: Mapped[dict] = db.Column(JSONB, nullable=False, default={}) - enabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, default=False) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id = mapped_column(StringUUID, nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False) + client_params: Mapped[dict] = mapped_column(JSONB, nullable=False, default={}) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) - updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now) diff --git a/api/models/provider.py b/api/models/provider.py index aacc6e505a..f6852d49f4 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -107,7 +107,7 @@ class Provider(Base): """ Returns True if the provider is enabled. """ - if self.provider_type == ProviderType.SYSTEM.value: + if self.provider_type == ProviderType.SYSTEM: return self.is_valid else: return self.is_valid and self.token_is_set diff --git a/api/models/task.py b/api/models/task.py index 3da1674536..4e49254dbd 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -8,8 +8,6 @@ from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now from models.base import Base -from .engine import db - class CeleryTask(Base): """Task result/status.""" @@ -19,7 +17,7 @@ class CeleryTask(Base): id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) task_id = mapped_column(String(155), unique=True) status = mapped_column(String(50), default=states.PENDING) - result = mapped_column(db.PickleType, nullable=True) + result = mapped_column(sa.PickleType, nullable=True) date_done = mapped_column( DateTime, default=lambda: naive_utc_now(), @@ -44,5 +42,5 @@ class CeleryTaskSet(Base): sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True ) taskset_id = mapped_column(String(155), unique=True) - result = mapped_column(db.PickleType, nullable=True) + result = mapped_column(sa.PickleType, nullable=True) date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 7211d7aa3a..d581d588a4 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -152,7 +152,7 @@ class ApiToolProvider(Base): def tools(self) -> list["ApiToolBundle"]: from core.tools.entities.tool_bundle import ApiToolBundle - return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] + return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)] @property def credentials(self) -> dict[str, Any]: @@ -242,7 +242,10 @@ class WorkflowToolProvider(Base): def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]: from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration - return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] + return [ + WorkflowToolParameterConfiguration.model_validate(config) + for config in json.loads(self.parameter_configuration) + ] @property def app(self) -> App | None: @@ -312,7 +315,7 @@ class MCPToolProvider(Base): def mcp_tools(self) -> list["MCPTool"]: from core.mcp.types import Tool as MCPTool - return [MCPTool(**tool) for tool in json.loads(self.tools)] + return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)] @property def provider_icon(self) -> Mapping[str, str] | str: @@ -552,4 +555,4 @@ class DeprecatedPublishedAppTool(Base): def description_i18n(self) -> "I18nObject": from core.tools.entities.common_entities import I18nObject - return I18nObject(**json.loads(self.description)) + return I18nObject.model_validate(json.loads(self.description)) diff --git a/api/models/workflow.py b/api/models/workflow.py index e61005953e..b898f02612 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -360,7 +360,9 @@ class Workflow(Base): @property def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: - # _environment_variables is guaranteed to be non-None due to server_default="{}" + # TODO: find some way to init `self._environment_variables` when instance created. + if self._environment_variables is None: + self._environment_variables = "{}" # Use workflow.tenant_id to avoid relying on request user in background threads tenant_id = self.tenant_id @@ -444,7 +446,9 @@ class Workflow(Base): @property def conversation_variables(self) -> Sequence[Variable]: - # _conversation_variables is guaranteed to be non-None due to server_default="{}" + # TODO: find some way to init `self._conversation_variables` when instance created. + if self._conversation_variables is None: + self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] @@ -825,14 +829,14 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo if self.execution_metadata_dict: from core.workflow.nodes import NodeType - if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: + if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict: tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - elif self.node_type == NodeType.DATASOURCE.value and "datasource_info" in self.execution_metadata_dict: + elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict: datasource_info = self.execution_metadata_dict["datasource_info"] extras["icon"] = datasource_info.get("icon") return extras diff --git a/api/pyproject.toml b/api/pyproject.toml index 0b2b41d6db..897d114dcc 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,11 +1,10 @@ [project] name = "dify-api" -version = "2.0.0-beta2" +version = "1.9.1" requires-python = ">=3.11,<3.13" dependencies = [ "arize-phoenix-otel~=0.9.2", - "authlib==1.6.4", "azure-identity==1.16.1", "beautifulsoup4==4.12.2", "boto3==1.35.99", @@ -34,10 +33,8 @@ dependencies = [ "json-repair>=0.41.1", "langfuse~=2.51.3", "langsmith~=0.1.77", - "mailchimp-transactional~=1.0.50", "markdown~=3.5.1", "numpy~=1.26.4", - "openai~=1.61.0", "openpyxl~=3.1.5", "opik~=1.7.25", "opentelemetry-api==1.27.0", @@ -49,8 +46,9 @@ dependencies = [ "opentelemetry-instrumentation==0.48b0", "opentelemetry-instrumentation-celery==0.48b0", "opentelemetry-instrumentation-flask==0.48b0", + "opentelemetry-instrumentation-httpx==0.48b0", "opentelemetry-instrumentation-redis==0.48b0", - "opentelemetry-instrumentation-requests==0.48b0", + "opentelemetry-instrumentation-httpx==0.48b0", "opentelemetry-instrumentation-sqlalchemy==0.48b0", "opentelemetry-propagator-b3==1.27.0", # opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0), @@ -60,7 +58,6 @@ dependencies = [ "opentelemetry-semantic-conventions==0.48b0", "opentelemetry-util-http==0.48b0", "pandas[excel,output-formatting,performance]~=2.2.2", - "pandoc~=2.4", "psycogreen~=1.0.2", "psycopg2-binary~=2.9.6", "pycryptodome==3.19.1", @@ -113,7 +110,7 @@ dev = [ "lxml-stubs~=0.5.1", "ty~=0.0.1a19", "basedpyright~=1.31.0", - "ruff~=0.12.3", + "ruff~=0.14.0", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", "pytest-cov~=4.1.0", @@ -148,8 +145,6 @@ dev = [ "types-pywin32~=310.0.0", "types-pyyaml~=6.0.12", "types-regex~=2024.11.6", - "types-requests~=2.32.0", - "types-requests-oauthlib~=2.0.0", "types-shapely~=2.0.0", "types-simplejson>=3.20.0", "types-six>=1.17.0", @@ -178,10 +173,10 @@ dev = [ # Required for storage clients ############################################################ storage = [ - "azure-storage-blob==12.13.0", + "azure-storage-blob==12.26.0", "bce-python-sdk~=0.9.23", - "cos-python-sdk-v5==1.9.30", - "esdk-obs-python==3.24.6.1", + "cos-python-sdk-v5==1.9.38", + "esdk-obs-python==3.25.8", "google-cloud-storage==2.16.0", "opendal~=0.46.0", "oss2==2.18.5", @@ -207,7 +202,7 @@ vdb = [ "couchbase~=4.3.0", "elasticsearch==8.14.0", "opensearch-py==2.4.0", - "oracledb==3.0.0", + "oracledb==3.3.0", "pgvecto-rs[sqlalchemy]~=0.2.1", "pgvector==0.2.5", "pymilvus~=2.5.0", @@ -222,4 +217,5 @@ vdb = [ "weaviate-client~=3.24.0", "xinference-client~=1.2.2", "mo-vector~=0.1.13", + "mysql-connector-python>=9.3.0", ] diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 61ed3ac3b4..bf4ec2314e 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -1,19 +1,10 @@ { "include": ["."], "exclude": [ - ".venv", "tests/", + ".venv", "migrations/", - "core/rag", - "extensions", - "libs", - "controllers/console/datasets", - "controllers/service_api/dataset", - "core/ops", - "core/tools", - "core/model_runtime", - "core/workflow/nodes", - "core/app/app_config/easy_ui_based_app/dataset" + "core/rag" ], "typeCheckingMode": "strict", "allowedUntypedLibraries": [ @@ -21,9 +12,11 @@ "flask_login", "opentelemetry.instrumentation.celery", "opentelemetry.instrumentation.flask", + "opentelemetry.instrumentation.httpx", "opentelemetry.instrumentation.requests", "opentelemetry.instrumentation.sqlalchemy", - "opentelemetry.instrumentation.redis" + "opentelemetry.instrumentation.redis", + "opentelemetry.instrumentation.httpx" ], "reportUnknownMemberType": "hint", "reportUnknownParameterType": "hint", @@ -32,13 +25,11 @@ "reportUnknownLambdaType": "hint", "reportMissingParameterType": "hint", "reportMissingTypeArgument": "hint", - "reportUnnecessaryContains": "hint", "reportUnnecessaryComparison": "hint", - "reportUnnecessaryCast": "hint", "reportUnnecessaryIsInstance": "hint", "reportUntypedFunctionDecorator": "hint", "reportAttributeAccessIssue": "hint", "pythonVersion": "3.11", "pythonPlatform": "All" -} +} \ No newline at end of file diff --git a/api/pytest.ini b/api/pytest.ini index eb49619481..afb53b47cc 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -7,7 +7,7 @@ env = CHATGLM_API_BASE = http://a.abc.com:11451 CODE_EXECUTION_API_KEY = dify-sandbox CODE_EXECUTION_ENDPOINT = http://127.0.0.1:8194 - CODE_MAX_STRING_LENGTH = 80000 + CODE_MAX_STRING_LENGTH = 400000 PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi PLUGIN_DAEMON_URL=http://127.0.0.1:5002 PLUGIN_MAX_PACKAGE_SIZE=15728640 diff --git a/api/schedule/check_upgradable_plugin_task.py b/api/schedule/check_upgradable_plugin_task.py index 08a5cfce79..e91ce07be3 100644 --- a/api/schedule/check_upgradable_plugin_task.py +++ b/api/schedule/check_upgradable_plugin_task.py @@ -1,3 +1,4 @@ +import math import time import click @@ -5,9 +6,10 @@ import click import app from extensions.ext_database import db from models.account import TenantPluginAutoUpgradeStrategy -from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task +from tasks import process_tenant_plugin_autoupgrade_check_task as check_task AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes +MAX_CONCURRENT_CHECK_TASKS = 20 @app.celery.task(queue="plugin") @@ -30,15 +32,29 @@ def check_upgradable_plugin_task(): .all() ) - for strategy in strategies: - process_tenant_plugin_autoupgrade_check_task.delay( - strategy.tenant_id, - strategy.strategy_setting, - strategy.upgrade_time_of_day, - strategy.upgrade_mode, - strategy.exclude_plugins, - strategy.include_plugins, - ) + total_strategies = len(strategies) + click.echo(click.style(f"Total strategies: {total_strategies}", fg="green")) + + batch_chunk_count = math.ceil( + total_strategies / MAX_CONCURRENT_CHECK_TASKS + ) # make sure all strategies are checked in this interval + batch_interval_time = (AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL / batch_chunk_count) if batch_chunk_count > 0 else 0 + + for i in range(0, total_strategies, MAX_CONCURRENT_CHECK_TASKS): + batch_strategies = strategies[i : i + MAX_CONCURRENT_CHECK_TASKS] + for strategy in batch_strategies: + check_task.process_tenant_plugin_autoupgrade_check_task.delay( + strategy.tenant_id, + strategy.strategy_setting, + strategy.upgrade_time_of_day, + strategy.upgrade_mode, + strategy.exclude_plugins, + strategy.include_plugins, + ) + + # Only sleep if batch_interval_time > 0.0001 AND current batch is not the last one + if batch_interval_time > 0.0001 and i + MAX_CONCURRENT_CHECK_TASKS < total_strategies: + time.sleep(batch_interval_time) end_at = time.perf_counter() click.echo( diff --git a/api/services/account_service.py b/api/services/account_service.py index 0e699d16da..106bc0e77e 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -127,7 +127,7 @@ class AccountService: if not account: return None - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() @@ -178,7 +178,7 @@ class AccountService: if not account: raise AccountPasswordError("Invalid email or password.") - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise AccountLoginError("Account is banned.") if password and invite_token and account.password is None: @@ -193,8 +193,8 @@ class AccountService: if account.password is None or not compare_password(password, account.password, account.password_salt): raise AccountPasswordError("Invalid email or password.") - if account.status == AccountStatus.PENDING.value: - account.status = AccountStatus.ACTIVE.value + if account.status == AccountStatus.PENDING: + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() @@ -246,10 +246,8 @@ class AccountService: ) ) - account = Account() - account.email = email - account.name = name - + password_to_set = None + salt_to_set = None if password: valid_password(password) @@ -261,14 +259,18 @@ class AccountService: password_hashed = hash_password(password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt + password_to_set = base64_password_hashed + salt_to_set = base64_salt - account.interface_language = interface_language - account.interface_theme = interface_theme - - # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, "UTC") + account = Account( + name=name, + email=email, + password=password_to_set, + password_salt=salt_to_set, + interface_language=interface_language, + interface_theme=interface_theme, + timezone=language_timezone_mapping.get(interface_language, "UTC"), + ) db.session.add(account) db.session.commit() @@ -355,7 +357,7 @@ class AccountService: @staticmethod def close_account(account: Account): """Close account""" - account.status = AccountStatus.CLOSED.value + account.status = AccountStatus.CLOSED db.session.commit() @staticmethod @@ -395,8 +397,8 @@ class AccountService: if ip_address: AccountService.update_login_info(account=account, ip_address=ip_address) - if account.status == AccountStatus.PENDING.value: - account.status = AccountStatus.ACTIVE.value + if account.status == AccountStatus.PENDING: + account.status = AccountStatus.ACTIVE db.session.commit() access_token = AccountService.get_account_jwt_token(account=account) @@ -764,7 +766,7 @@ class AccountService: if not account: return None - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") return account @@ -1028,7 +1030,7 @@ class TenantService: @staticmethod def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: """Create tenant member""" - if role == TenantAccountRole.OWNER.value: + if role == TenantAccountRole.OWNER: if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]): logger.error("Tenant %s has already an owner.", tenant.id) raise Exception("Tenant already has an owner.") @@ -1313,7 +1315,7 @@ class RegisterService: password=password, is_setup=is_setup, ) - account.status = AccountStatus.ACTIVE.value if not status else status.value + account.status = status or AccountStatus.ACTIVE account.initialized_at = naive_utc_now() if open_id is not None and provider is not None: @@ -1374,7 +1376,7 @@ class RegisterService: TenantService.create_tenant_member(tenant, account, role) # Support resend invitation email when the account is pending status - if account.status != AccountStatus.PENDING.value: + if account.status != AccountStatus.PENDING: raise AccountAlreadyInTenantError("Account already in tenant.") token = cls.generate_invite_token(tenant, account) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 8701fe4f4e..e2915ebfbb 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -29,6 +29,7 @@ from core.workflow.nodes.tool.entities import ToolNodeData from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory +from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig from models.workflow import Workflow @@ -439,6 +440,7 @@ class AppDslService: app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id + app.updated_at = naive_utc_now() else: if account.current_tenant_id is None: raise ValueError("Current tenant is not set") @@ -494,7 +496,7 @@ class AppDslService: unique_hash = None graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ decrypted_id @@ -584,17 +586,17 @@ class AppDslService: if not node_data: continue data_type = node_data.get("type", "") - if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value: + if data_type == NodeType.KNOWLEDGE_RETRIEVAL: dataset_ids = node_data.get("dataset_ids", []) node_data["dataset_ids"] = [ cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] # filter credential id from tool node - if not include_secret and data_type == NodeType.TOOL.value: + if not include_secret and data_type == NodeType.TOOL: node_data.pop("credential_id", None) # filter credential id from agent node - if not include_secret and data_type == NodeType.AGENT.value: + if not include_secret and data_type == NodeType.AGENT: for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): tool.pop("credential_id", None) @@ -658,32 +660,32 @@ class AppDslService: try: typ = node.get("data", {}).get("type") match typ: - case NodeType.TOOL.value: - tool_entity = ToolNodeData(**node["data"]) + case NodeType.TOOL: + tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) - case NodeType.LLM.value: - llm_entity = LLMNodeData(**node["data"]) + case NodeType.LLM: + llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) - case NodeType.QUESTION_CLASSIFIER.value: - question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + case NodeType.QUESTION_CLASSIFIER: + question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) - case NodeType.PARAMETER_EXTRACTOR.value: - parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + case NodeType.PARAMETER_EXTRACTOR: + parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) - case NodeType.KNOWLEDGE_RETRIEVAL.value: - knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + case NodeType.KNOWLEDGE_RETRIEVAL: + knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: if ( @@ -773,7 +775,7 @@ class AppDslService: """ Returns the leaked dependencies in current workspace """ - dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies] if not dependencies: return [] diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 8911da4728..b462ddf236 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -2,8 +2,6 @@ import uuid from collections.abc import Generator, Mapping from typing import Any, Union -from openai._exceptions import RateLimitError - from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator @@ -122,8 +120,6 @@ class AppGenerateService: ) else: raise ValueError(f"Invalid app mode {app_model.mode}") - except RateLimitError as e: - raise InvokeRateLimitError(str(e)) except Exception: rate_limit.exit(request_id) raise diff --git a/api/services/app_service.py b/api/services/app_service.py index d524adbf3e..4fc6cf2494 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -2,6 +2,7 @@ import json import logging from typing import TypedDict, cast +import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination from configs import dify_config @@ -65,7 +66,7 @@ class AppService: return None app_models = db.paginate( - db.select(App).where(*filters).order_by(App.created_at.desc()), + sa.select(App).where(*filters).order_by(App.created_at.desc()), page=args["page"], per_page=args["limit"], error_out=False, diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index 6ef034f292..d455475bfc 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -1,6 +1,6 @@ import json -import requests +import httpx from services.auth.api_key_auth_base import ApiKeyAuthBase @@ -36,7 +36,7 @@ class FirecrawlAuth(ApiKeyAuthBase): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): - return requests.post(url, headers=headers, json=data) + return httpx.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in {402, 409, 500}: diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index 6100e9afc8..afaed28ac9 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -1,6 +1,6 @@ import json -import requests +import httpx from services.auth.api_key_auth_base import ApiKeyAuthBase @@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): - return requests.post(url, headers=headers, json=data) + return httpx.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in {402, 409, 500}: diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index 6100e9afc8..afaed28ac9 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -1,6 +1,6 @@ import json -import requests +import httpx from services.auth.api_key_auth_base import ApiKeyAuthBase @@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): - return requests.post(url, headers=headers, json=data) + return httpx.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in {402, 409, 500}: diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index 153ab5ba75..b2d28a83d1 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -1,7 +1,7 @@ import json from urllib.parse import urljoin -import requests +import httpx from services.auth.api_key_auth_base import ApiKeyAuthBase @@ -31,7 +31,7 @@ class WatercrawlAuth(ApiKeyAuthBase): return {"Content-Type": "application/json", "X-API-KEY": self.api_key} def _get_request(self, url, headers): - return requests.get(url, headers=headers) + return httpx.get(url, headers=headers) def _handle_error(self, response): if response.status_code in {402, 409, 500}: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8b3720026d..87861ada87 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -93,7 +93,7 @@ logger = logging.getLogger(__name__) class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): - query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) + query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id) if user: # get permitted dataset ids @@ -115,12 +115,12 @@ class DatasetService: # Check if permitted_dataset_ids is not empty to avoid WHERE false condition if permitted_dataset_ids and len(permitted_dataset_ids) > 0: query = query.where( - db.or_( + sa.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, - db.and_( + sa.and_( Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id ), - db.and_( + sa.and_( Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, Dataset.id.in_(permitted_dataset_ids), ), @@ -128,9 +128,9 @@ class DatasetService: ) else: query = query.where( - db.or_( + sa.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, - db.and_( + sa.and_( Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id ), ) @@ -1879,7 +1879,7 @@ class DocumentService: # for notion_info in notion_info_list: # workspace_id = notion_info.workspace_id # data_source_binding = DataSourceOauthBinding.query.filter( - # db.and_( + # sa.and_( # DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, # DataSourceOauthBinding.provider == "notion", # DataSourceOauthBinding.disabled == False, diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 89a5d89f61..36b7084973 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -646,7 +646,7 @@ class DatasourceProviderService: name=db_provider_name, provider=provider_name, plugin_id=plugin_id, - auth_type=CredentialType.API_KEY.value, + auth_type=CredentialType.API_KEY, encrypted_credentials=credentials, ) session.add(datasource_provider) @@ -674,7 +674,7 @@ class DatasourceProviderService: secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: - if credential_form_schema.type.value == FormType.SECRET_INPUT.value: + if credential_form_schema.type.value == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.name) return secret_input_form_variables diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index edb76408e8..bdc960aa2d 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -1,10 +1,12 @@ import os +from collections.abc import Mapping +from typing import Any -import requests +import httpx class BaseRequest: - proxies = { + proxies: Mapping[str, str] | None = { "http": "", "https": "", } @@ -13,10 +15,31 @@ class BaseRequest: secret_key_header = "" @classmethod - def send_request(cls, method, endpoint, json=None, params=None): + def _build_mounts(cls) -> dict[str, httpx.BaseTransport] | None: + if not cls.proxies: + return None + + mounts: dict[str, httpx.BaseTransport] = {} + for scheme, value in cls.proxies.items(): + if not value: + continue + key = f"{scheme}://" if not scheme.endswith("://") else scheme + mounts[key] = httpx.HTTPTransport(proxy=value) + return mounts or None + + @classmethod + def send_request( + cls, + method: str, + endpoint: str, + json: Any | None = None, + params: Mapping[str, Any] | None = None, + ) -> Any: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) + mounts = cls._build_mounts() + with httpx.Client(mounts=mounts) as client: + response = client.request(method, url, json=json, params=params, headers=headers) return response.json() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index f8612456d6..4fbf33fd6f 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -70,7 +70,7 @@ class EnterpriseService: data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) if not data: raise ValueError("No data found.") - return WebAppSettings(**data) + return WebAppSettings.model_validate(data) @classmethod def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]: @@ -100,7 +100,7 @@ class EnterpriseService: data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) if not data: raise ValueError("No data found.") - return WebAppSettings(**data) + return WebAppSettings.model_validate(data) @classmethod def update_app_access_mode(cls, app_id: str, access_mode: str): diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index ac96b5c8ad..860bfde401 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -83,7 +83,7 @@ class RetrievalSetting(BaseModel): Retrieval Setting. """ - search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"] + search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"] top_k: int score_threshold: float | None = 0.5 score_threshold_enabled: bool = False diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 49d48f044c..0f5151919f 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,6 +1,7 @@ +from collections.abc import Sequence from enum import Enum -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config from core.entities.model_entities import ( @@ -71,7 +72,7 @@ class ProviderResponse(BaseModel): icon_large: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None - supported_model_types: list[ModelType] + supported_model_types: Sequence[ModelType] configurate_methods: list[ConfigurateMethod] provider_credential_schema: ProviderCredentialSchema | None = None model_credential_schema: ModelCredentialSchema | None = None @@ -82,9 +83,8 @@ class ProviderResponse(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -97,6 +97,7 @@ class ProviderResponse(BaseModel): self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) + return self class ProviderWithModelsResponse(BaseModel): @@ -112,9 +113,8 @@ class ProviderWithModelsResponse(BaseModel): status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -127,6 +127,7 @@ class ProviderWithModelsResponse(BaseModel): self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) + return self class SimpleProviderEntityResponse(SimpleProviderEntity): @@ -136,9 +137,8 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): tenant_id: str - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -151,6 +151,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) + return self class DefaultModelResponse(BaseModel): diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 00ec3babf3..aa29354a6e 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -15,7 +15,7 @@ from models.dataset import Dataset, DatasetQuery logger = logging.getLogger(__name__) default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 4, @@ -46,7 +46,7 @@ class HitTestingService: from core.app.app_config.entities import MetadataFilteringCondition - metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions) + metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions) metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( dataset_ids=[dataset.id], diff --git a/api/services/operation_service.py b/api/services/operation_service.py index 8c8b64bcd5..c05e9d555c 100644 --- a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -1,6 +1,6 @@ import os -import requests +import httpx class OperationService: @@ -12,7 +12,7 @@ class OperationService: headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = requests.request(method, url, json=json, params=params, headers=headers) + response = httpx.request(method, url, json=json, params=params, headers=headers) return response.json() diff --git a/api/services/ops_service.py b/api/services/ops_service.py index c214640653..b4b23b8360 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -123,7 +123,7 @@ class OpsService: config_class: type[BaseTracingConfig] = provider_config["config_class"] other_keys: list[str] = provider_config["other_keys"] - default_config_instance: BaseTracingConfig = config_class(**tracing_config) + default_config_instance = config_class.model_validate(tracing_config) for key in other_keys: if key in tracing_config and tracing_config[key] == "": tracing_config[key] = getattr(default_config_instance, key, None) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 5db19711e6..dec92a6faa 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -242,7 +242,7 @@ class PluginMigration: if data.get("type") == "tool": provider_name = data.get("provider_name") provider_type = data.get("provider_type") - if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value: + if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN: result.append(ToolProviderID(provider_name).plugin_id) return result @@ -269,9 +269,9 @@ class PluginMigration: for tool in agent_config["tools"]: if isinstance(tool, dict): try: - tool_entity = AgentToolEntity(**tool) + tool_entity = AgentToolEntity.model_validate(tool) if ( - tool_entity.provider_type == ToolProviderType.BUILT_IN.value + tool_entity.provider_type == ToolProviderType.BUILT_IN and tool_entity.provider_id not in excluded_providers ): result.append(ToolProviderID(tool_entity.provider_id).plugin_id) @@ -471,7 +471,7 @@ class PluginMigration: total_failed_tenant = 0 while True: # paginate - tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100) + tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100) if tenants.items is None or len(tenants.items) == 0: break diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 8f96842337..571ca6c7a6 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -1,6 +1,6 @@ import logging -import requests +import httpx from configs import dify_config from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval @@ -43,7 +43,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/pipeline-templates/{template_id}" - response = requests.get(url, timeout=(3, 10)) + response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: return None data: dict = response.json() @@ -58,7 +58,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/pipeline-templates?language={language}" - response = requests.get(url, timeout=(3, 10)) + response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}") diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index fdaaa73bcc..13c0ca7392 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -358,7 +358,7 @@ class RagPipelineService: for node in nodes: if node.get("data", {}).get("type") == "knowledge-index": knowledge_configuration = node.get("data", {}) - knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration) # update dataset dataset = pipeline.retrieve_dataset(session=session) @@ -873,7 +873,7 @@ class RagPipelineService: variable_pool = node_instance.graph_runtime_state.variable_pool invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) if invoke_from: - if invoke_from.value == InvokeFrom.PUBLISHED.value: + if invoke_from.value == InvokeFrom.PUBLISHED: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index f74de1bcab..c02fad4dc6 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -288,7 +288,7 @@ class RagPipelineDslService: dataset_id = None for node in nodes: if node.get("data", {}).get("type") == "knowledge-index": - knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) + knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) if ( dataset and pipeline.is_published @@ -426,7 +426,7 @@ class RagPipelineDslService: dataset_id = None for node in nodes: if node.get("data", {}).get("type") == "knowledge-index": - knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) + knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) if not dataset: dataset = Dataset( tenant_id=account.current_tenant_id, @@ -556,7 +556,7 @@ class RagPipelineDslService: graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ decrypted_id @@ -613,7 +613,7 @@ class RagPipelineDslService: tenant_id=pipeline.tenant_id, app_id=pipeline.id, features="{}", - type=WorkflowType.RAG_PIPELINE.value, + type=WorkflowType.RAG_PIPELINE, version="draft", graph=json.dumps(graph), created_by=account.id, @@ -689,17 +689,17 @@ class RagPipelineDslService: if not node_data: continue data_type = node_data.get("type", "") - if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value: + if data_type == NodeType.KNOWLEDGE_RETRIEVAL: dataset_ids = node_data.get("dataset_ids", []) node["data"]["dataset_ids"] = [ self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) for dataset_id in dataset_ids ] # filter credential id from tool node - if not include_secret and data_type == NodeType.TOOL.value: + if not include_secret and data_type == NodeType.TOOL: node_data.pop("credential_id", None) # filter credential id from agent node - if not include_secret and data_type == NodeType.AGENT.value: + if not include_secret and data_type == NodeType.AGENT: for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): tool.pop("credential_id", None) @@ -733,36 +733,36 @@ class RagPipelineDslService: try: typ = node.get("data", {}).get("type") match typ: - case NodeType.TOOL.value: - tool_entity = ToolNodeData(**node["data"]) + case NodeType.TOOL: + tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) - case NodeType.DATASOURCE.value: - datasource_entity = DatasourceNodeData(**node["data"]) + case NodeType.DATASOURCE: + datasource_entity = DatasourceNodeData.model_validate(node["data"]) if datasource_entity.provider_type != "local_file": dependencies.append(datasource_entity.plugin_id) - case NodeType.LLM.value: - llm_entity = LLMNodeData(**node["data"]) + case NodeType.LLM: + llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) - case NodeType.QUESTION_CLASSIFIER.value: - question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + case NodeType.QUESTION_CLASSIFIER: + question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) - case NodeType.PARAMETER_EXTRACTOR.value: - parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + case NodeType.PARAMETER_EXTRACTOR: + parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) - case NodeType.KNOWLEDGE_INDEX.value: - knowledge_index_entity = KnowledgeConfiguration(**node["data"]) + case NodeType.KNOWLEDGE_INDEX: + knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) if knowledge_index_entity.indexing_technique == "high_quality": if knowledge_index_entity.embedding_model_provider: dependencies.append( @@ -782,8 +782,8 @@ class RagPipelineDslService: knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name ), ) - case NodeType.KNOWLEDGE_RETRIEVAL.value: - knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + case NodeType.KNOWLEDGE_RETRIEVAL: + knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: if ( @@ -873,7 +873,7 @@ class RagPipelineDslService: """ Returns the leaked dependencies in current workspace """ - dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies] if not dependencies: return [] @@ -927,7 +927,7 @@ class RagPipelineDslService: account = cast(Account, current_user) rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline( account=account, - import_mode=ImportMode.YAML_CONTENT.value, + import_mode=ImportMode.YAML_CONTENT, yaml_content=rag_pipeline_dataset_create_entity.yaml_content, dataset=None, dataset_name=rag_pipeline_dataset_create_entity.name, diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index db9508824b..39f426a2b0 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -149,21 +149,20 @@ class RagPipelineTransformService: file_extensions = node.get("data", {}).get("fileExtensions", []) if not file_extensions: return node - file_extensions = [file_extension.lower() for file_extension in file_extensions] - node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS + node["data"]["fileExtensions"] = [ext.lower() for ext in file_extensions if ext in DOCUMENT_EXTENSIONS] return node def _deal_knowledge_index( self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict ): knowledge_configuration_dict = node.get("data", {}) - knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict) + knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict) if indexing_technique == "high_quality": knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - retrieval_setting = RetrievalSetting(**retrieval_model) + retrieval_setting = RetrievalSetting.model_validate(retrieval_model) if indexing_technique == "economy": retrieval_setting.search_method = "keyword_search" knowledge_configuration.retrieval_model = retrieval_setting @@ -215,7 +214,7 @@ class RagPipelineTransformService: tenant_id=pipeline.tenant_id, app_id=pipeline.id, features="{}", - type=WorkflowType.RAG_PIPELINE.value, + type=WorkflowType.RAG_PIPELINE, version="draft", graph=json.dumps(graph), created_by=current_user.id, @@ -227,7 +226,7 @@ class RagPipelineTransformService: tenant_id=pipeline.tenant_id, app_id=pipeline.id, features="{}", - type=WorkflowType.RAG_PIPELINE.value, + type=WorkflowType.RAG_PIPELINE, version=str(datetime.now(UTC).replace(tzinfo=None)), graph=json.dumps(graph), created_by=current_user.id, diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 2d57769f63..b217c9026a 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -1,6 +1,6 @@ import logging -import requests +import httpx from configs import dify_config from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval @@ -43,7 +43,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/apps/{app_id}" - response = requests.get(url, timeout=(3, 10)) + response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: return None data: dict = response.json() @@ -58,7 +58,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/apps?language={language}" - response = requests.get(url, timeout=(3, 10)) + response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 4674335ba8..db7ed3d5c3 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,5 +1,6 @@ import uuid +import sqlalchemy as sa from flask_login import current_user from sqlalchemy import func, select from werkzeug.exceptions import NotFound @@ -18,7 +19,7 @@ class TagService: .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%"))) + query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index f86d7e51bf..2c0c63f634 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -277,7 +277,7 @@ class ApiToolManageService: provider.icon = json.dumps(icon) provider.schema = schema provider.description = extra_info.get("description", "") - provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value + provider.schema_type_str = ApiProviderSchemaType.OPENAPI provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer @@ -393,7 +393,7 @@ class ApiToolManageService: icon="", schema=schema, description="", - schema_type_str=ApiProviderSchemaType.OPENAPI.value, + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6b0b6b0f0e..cab4a5c6ab 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -349,14 +349,10 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) credentials: list[ToolProviderCredentialApiEntity] = [] - encrypters = {} for provider in providers: - credential_type = provider.credential_type - if credential_type not in encrypters: - encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter( - tenant_id, provider, provider.provider, provider_controller - )[0] - encrypter = encrypters[credential_type] + encrypter, _ = BuiltinToolManageService.create_tool_encrypter( + tenant_id, provider, provider.provider, provider_controller + ) decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials)) credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( provider=provider, diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index dd626dd615..605ad8379b 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -1,7 +1,7 @@ import hashlib import json from datetime import datetime -from typing import Any, cast +from typing import Any from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError @@ -55,7 +55,7 @@ class MCPToolManageService: cache=NoOpProviderCredentialCache(), ) - return cast(dict[str, str], encrypter_instance.encrypt(headers)) + return encrypter_instance.encrypt(headers) @staticmethod def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider: diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 6b36ed0eb7..81b4d6993a 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -50,16 +50,16 @@ class ToolTransformService: URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider" ) - if provider_type == ToolProviderType.BUILT_IN.value: + if provider_type == ToolProviderType.BUILT_IN: return str(url_prefix / "builtin" / provider_name / "icon") - elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: + elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}: try: if isinstance(icon, str): return json.loads(icon) return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} - elif provider_type == ToolProviderType.MCP.value: + elif provider_type == ToolProviderType.MCP: return icon return "" @@ -242,7 +242,7 @@ class ToolTransformService: is_team_authorization=db_provider.authed, server_url=db_provider.masked_server_url, tools=ToolTransformService.mcp_tool_to_user_tool( - db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] + db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)] ), updated_at=int(db_provider.updated_at.timestamp()), label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), @@ -387,6 +387,7 @@ class ToolTransformService: labels=labels or [], ) else: + assert tool.operation_id return ToolApiEntity( author=tool.author, name=tool.operation_id or "", diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 4362bb0291..d02508e4f3 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -262,6 +262,14 @@ class VariableTruncator: target_length = self._array_element_limit for i, item in enumerate(value): + # Dirty fix: + # The output of `Start` node may contain list of `File` elements, + # causing `AssertionError` while invoking `_truncate_json_primitives`. + # + # This check ensures that `list[File]` are handled separately + if isinstance(item, File): + truncated_value.append(item) + continue if i >= target_length: return _PartResult(truncated_value, used_size, True) if i > 0: diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 1c559f2c2b..abc92a0181 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -134,7 +134,7 @@ class VectorService: ) # use full doc mode to generate segment's child chunk processing_rule_dict = processing_rule.to_dict() - processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC documents = index_processor.transform( [document], embedding_model_instance=embedding_model_instance, diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 066dc9d741..d30e14f7a1 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -36,7 +36,7 @@ class WebAppAuthService: if not account: raise AccountNotFoundError() - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise AccountLoginError("Account is banned.") if account.password is None or not compare_password(password, account.password, account.password_salt): @@ -56,7 +56,7 @@ class WebAppAuthService: if not account: return None - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") return account diff --git a/api/services/website_service.py b/api/services/website_service.py index 7634fdd8f3..37588d6ba5 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from typing import Any -import requests +import httpx from flask_login import current_user from core.helper import encrypter @@ -216,7 +216,7 @@ class WebsiteService: @classmethod def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: if not request.options.crawl_sub_pages: - response = requests.get( + response = httpx.get( f"https://r.jina.ai/{request.url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) @@ -224,7 +224,7 @@ class WebsiteService: raise ValueError("Failed to crawl:") return {"status": "active", "data": response.json().get("data")} else: - response = requests.post( + response = httpx.post( "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", json={ "url": request.url, @@ -287,7 +287,7 @@ class WebsiteService: @classmethod def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: - response = requests.post( + response = httpx.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id}, @@ -303,7 +303,7 @@ class WebsiteService: } if crawl_status_data["status"] == "completed": - response = requests.post( + response = httpx.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, @@ -362,7 +362,7 @@ class WebsiteService: @classmethod def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: if not job_id: - response = requests.get( + response = httpx.get( f"https://r.jina.ai/{url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) @@ -371,7 +371,7 @@ class WebsiteService: return dict(response.json().get("data", {})) else: # Get crawl status first - status_response = requests.post( + status_response = httpx.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id}, @@ -381,7 +381,7 @@ class WebsiteService: raise ValueError("Crawl job is not completed") # Get processed data - data_response = requests.post( + data_response = httpx.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index dccd891981..9c09f54bf5 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -79,7 +79,6 @@ class WorkflowConverter: new_app.updated_by = account.id db.session.add(new_app) db.session.flush() - db.session.commit() workflow.app_id = new_app.id db.session.commit() @@ -229,7 +228,7 @@ class WorkflowConverter: "position": None, "data": { "title": "START", - "type": NodeType.START.value, + "type": NodeType.START, "variables": [jsonable_encoder(v) for v in variables], }, } @@ -274,7 +273,7 @@ class WorkflowConverter: inputs[v.variable] = "{{#start." + v.variable + "#}}" request_body = { - "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, "params": { "app_id": app_model.id, "tool_variable": tool_variable, @@ -291,7 +290,7 @@ class WorkflowConverter: "position": None, "data": { "title": f"HTTP REQUEST {api_based_extension.name}", - "type": NodeType.HTTP_REQUEST.value, + "type": NodeType.HTTP_REQUEST, "method": "post", "url": api_based_extension.api_endpoint, "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, @@ -309,7 +308,7 @@ class WorkflowConverter: "position": None, "data": { "title": f"Parse {api_based_extension.name} Response", - "type": NodeType.CODE.value, + "type": NodeType.CODE, "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], "code_language": "python3", "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" @@ -349,7 +348,7 @@ class WorkflowConverter: "position": None, "data": { "title": "KNOWLEDGE RETRIEVAL", - "type": NodeType.KNOWLEDGE_RETRIEVAL.value, + "type": NodeType.KNOWLEDGE_RETRIEVAL, "query_variable_selector": query_variable_selector, "dataset_ids": dataset_config.dataset_ids, "retrieval_mode": retrieve_config.retrieve_strategy.value, @@ -397,16 +396,16 @@ class WorkflowConverter: :param external_data_variable_node_mapping: external data variable node mapping """ # fetch start and knowledge retrieval node - start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"])) + start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START, graph["nodes"])) knowledge_retrieval_node = next( - filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None + filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL, graph["nodes"]), None ) role_prefix = None prompts: Any | None = None # Chat Model - if model_config.mode == LLMMode.CHAT.value: + if model_config.mode == LLMMode.CHAT: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: if not prompt_template.simple_prompt_template: raise ValueError("Simple prompt template is required") @@ -518,7 +517,7 @@ class WorkflowConverter: "position": None, "data": { "title": "LLM", - "type": NodeType.LLM.value, + "type": NodeType.LLM, "model": { "provider": model_config.provider, "name": model_config.model, @@ -573,7 +572,7 @@ class WorkflowConverter: "position": None, "data": { "title": "END", - "type": NodeType.END.value, + "type": NodeType.END, "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], }, } @@ -587,7 +586,7 @@ class WorkflowConverter: return { "id": "answer", "position": None, - "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"}, + "data": {"title": "ANSWER", "type": NodeType.ANSWER, "answer": "{{#llm.text#}}"}, } def _create_edge(self, source: str, target: str): diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 1378c20128..344b7486ee 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -569,7 +569,7 @@ class WorkflowDraftVariableService: system_instruction="", system_instruction_tokens=0, status="normal", - invoke_from=InvokeFrom.DEBUGGER.value, + invoke_from=InvokeFrom.DEBUGGER, from_source="console", from_end_user_id=None, from_account_id=account_id, diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 79d91cab4c..6a2edd912a 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -74,7 +74,7 @@ class WorkflowRunService: return self._workflow_run_repo.get_paginated_workflow_runs( tenant_id=app_model.tenant_id, app_id=app_model.id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, limit=limit, last_id=last_id, ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 359fdb85fd..dea6a657a4 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1006,7 +1006,7 @@ def _setup_variable_pool( ) # Only add chatflow-specific variables for non-workflow types - if workflow.type != WorkflowType.WORKFLOW.value: + if workflow.type != WorkflowType.WORKFLOW: system_variable.query = query system_variable.conversation_id = conversation_id system_variable.dialogue_count = 1 diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 10da9a9af4..4c1f38c3bb 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -2,6 +2,7 @@ import logging import time import click +import sqlalchemy as sa from celery import shared_task from sqlalchemy import select @@ -51,7 +52,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): data_source_binding = ( db.session.query(DataSourceOauthBinding) .where( - db.and_( + sa.and_( DataSourceOauthBinding.tenant_id == document.tenant_id, DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 7b254ac3b5..72e3b42ca7 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -36,7 +36,7 @@ def process_trace_tasks(file_info): if trace_info.get("workflow_data"): trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"]) if trace_info.get("documents"): - trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]] + trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: if trace_instance: diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index bae8f1c4db..124971e8e2 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -1,5 +1,5 @@ +import json import operator -import traceback import typing import click @@ -9,38 +9,106 @@ from core.helper import marketplace from core.helper.marketplace import MarketplacePluginDeclaration from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller +from extensions.ext_redis import redis_client from models.account import TenantPluginAutoUpgradeStrategy RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3 +CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:" +CACHE_REDIS_TTL = 60 * 15 # 15 minutes -cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {} +def _get_redis_cache_key(plugin_id: str) -> str: + """Generate Redis cache key for plugin manifest.""" + return f"{CACHE_REDIS_KEY_PREFIX}{plugin_id}" + + +def _get_cached_manifest(plugin_id: str) -> typing.Union[MarketplacePluginDeclaration, None, bool]: + """ + Get cached plugin manifest from Redis. + Returns: + - MarketplacePluginDeclaration: if found in cache + - None: if cached as not found (marketplace returned no result) + - False: if not in cache at all + """ + try: + key = _get_redis_cache_key(plugin_id) + cached_data = redis_client.get(key) + if cached_data is None: + return False + + cached_json = json.loads(cached_data) + if cached_json is None: + return None + + return MarketplacePluginDeclaration.model_validate(cached_json) + except Exception: + return False + + +def _set_cached_manifest(plugin_id: str, manifest: typing.Union[MarketplacePluginDeclaration, None]) -> None: + """ + Cache plugin manifest in Redis. + Args: + plugin_id: The plugin ID + manifest: The manifest to cache, or None if not found in marketplace + """ + try: + key = _get_redis_cache_key(plugin_id) + if manifest is None: + # Cache the fact that this plugin was not found + redis_client.setex(key, CACHE_REDIS_TTL, json.dumps(None)) + else: + # Cache the manifest data + redis_client.setex(key, CACHE_REDIS_TTL, manifest.model_dump_json()) + except Exception: + # If Redis fails, continue without caching + # traceback.print_exc() + pass def marketplace_batch_fetch_plugin_manifests( plugin_ids_plain_list: list[str], ) -> list[MarketplacePluginDeclaration]: - global cached_plugin_manifests - # return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list) - not_included_plugin_ids = [ - plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests - ] - if not_included_plugin_ids: - manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids) + """Fetch plugin manifests with Redis caching support.""" + cached_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {} + not_cached_plugin_ids: list[str] = [] + + # Check Redis cache for each plugin + for plugin_id in plugin_ids_plain_list: + cached_result = _get_cached_manifest(plugin_id) + if cached_result is False: + # Not in cache, need to fetch + not_cached_plugin_ids.append(plugin_id) + else: + # Either found manifest or cached as None (not found in marketplace) + # At this point, cached_result is either MarketplacePluginDeclaration or None + if isinstance(cached_result, bool): + # This should never happen due to the if condition above, but for type safety + continue + cached_manifests[plugin_id] = cached_result + + # Fetch uncached plugins from marketplace + if not_cached_plugin_ids: + manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_cached_plugin_ids) + + # Cache the fetched manifests for manifest in manifests: - cached_plugin_manifests[manifest.plugin_id] = manifest + cached_manifests[manifest.plugin_id] = manifest + _set_cached_manifest(manifest.plugin_id, manifest) - if ( - len(manifests) == 0 - ): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check - for plugin_id in not_included_plugin_ids: - cached_plugin_manifests[plugin_id] = None + # Cache plugins that were not found in marketplace + fetched_plugin_ids = {manifest.plugin_id for manifest in manifests} + for plugin_id in not_cached_plugin_ids: + if plugin_id not in fetched_plugin_ids: + cached_manifests[plugin_id] = None + _set_cached_manifest(plugin_id, None) + # Build result list from cached manifests result: list[MarketplacePluginDeclaration] = [] for plugin_id in plugin_ids_plain_list: - final_manifest = cached_plugin_manifests.get(plugin_id) - if final_manifest is not None: - result.append(final_manifest) + cached_manifest: typing.Union[MarketplacePluginDeclaration, None] = cached_manifests.get(plugin_id) + if cached_manifest is not None: + result.append(cached_manifest) return result @@ -157,10 +225,10 @@ def process_tenant_plugin_autoupgrade_check_task( ) except Exception as e: click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red")) - traceback.print_exc() + # traceback.print_exc() break except Exception as e: click.echo(click.style(f"Error when checking upgradable plugin: {e}", fg="red")) - traceback.print_exc() + # traceback.print_exc() return diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 028f635188..4171656131 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -29,23 +29,10 @@ def priority_rag_pipeline_run_task( tenant_id: str, ): """ - Async Run rag pipeline - :param rag_pipeline_invoke_entities: Rag pipeline invoke entities - rag_pipeline_invoke_entities include: - :param pipeline_id: Pipeline ID - :param user_id: User ID - :param tenant_id: Tenant ID - :param workflow_id: Workflow ID - :param invoke_from: Invoke source (debugger, published, etc.) - :param streaming: Whether to stream results - :param datasource_type: Type of datasource - :param datasource_info: Datasource information dict - :param batch: Batch identifier - :param document_id: Document ID (optional) - :param start_node_id: Starting node ID - :param inputs: Input parameters dict - :param workflow_execution_id: Workflow execution ID - :param workflow_thread_pool_id: Thread pool ID for workflow execution + Async Run rag pipeline task using high priority queue. + + :param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities + :param tenant_id: Tenant ID for the pipeline execution """ # run with threading, thread pool size is 10 @@ -92,7 +79,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], # Create Flask application context for this thread with flask_app.app_context(): try: - rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity) user_id = rag_pipeline_invoke_entity_model.user_id tenant_id = rag_pipeline_invoke_entity_model.tenant_id pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id @@ -125,7 +112,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_execution_id = str(uuid.uuid4()) # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) + entity = RagPipelineGenerateEntity.model_validate(application_generate_entity) # Create workflow repositories session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index ee904c4649..90ebe80daf 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -30,23 +30,10 @@ def rag_pipeline_run_task( tenant_id: str, ): """ - Async Run rag pipeline - :param rag_pipeline_invoke_entities: Rag pipeline invoke entities - rag_pipeline_invoke_entities include: - :param pipeline_id: Pipeline ID - :param user_id: User ID - :param tenant_id: Tenant ID - :param workflow_id: Workflow ID - :param invoke_from: Invoke source (debugger, published, etc.) - :param streaming: Whether to stream results - :param datasource_type: Type of datasource - :param datasource_info: Datasource information dict - :param batch: Batch identifier - :param document_id: Document ID (optional) - :param start_node_id: Starting node ID - :param inputs: Input parameters dict - :param workflow_execution_id: Workflow execution ID - :param workflow_thread_pool_id: Thread pool ID for workflow execution + Async Run rag pipeline task using regular priority queue. + + :param rag_pipeline_invoke_entities_file_id: File ID containing serialized RAG pipeline invoke entities + :param tenant_id: Tenant ID for the pipeline execution """ # run with threading, thread pool size is 10 @@ -113,7 +100,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], # Create Flask application context for this thread with flask_app.app_context(): try: - rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity) user_id = rag_pipeline_invoke_entity_model.user_id tenant_id = rag_pipeline_invoke_entity_model.tenant_id pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id @@ -146,7 +133,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_execution_id = str(uuid.uuid4()) # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) + entity = RagPipelineGenerateEntity.model_validate(application_generate_entity) # Create workflow repositories session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) diff --git a/api/tasks/workflow_draft_var_tasks.py b/api/tasks/workflow_draft_var_tasks.py index 457d46a9d8..fcb98ec39e 100644 --- a/api/tasks/workflow_draft_var_tasks.py +++ b/api/tasks/workflow_draft_var_tasks.py @@ -5,15 +5,10 @@ These tasks provide asynchronous storage capabilities for workflow execution dat improving performance by offloading storage operations to background workers. """ -import logging - from celery import shared_task # type: ignore[import-untyped] from sqlalchemy.orm import Session from extensions.ext_database import db - -_logger = logging.getLogger(__name__) - from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService diff --git a/api/tests/fixtures/workflow/test-answer-order.yml b/api/tests/fixtures/workflow/test-answer-order.yml new file mode 100644 index 0000000000..3c6631aebb --- /dev/null +++ b/api/tests/fixtures/workflow/test-answer-order.yml @@ -0,0 +1,222 @@ +app: + description: 'this is a chatflow with 2 answer nodes. + + + it''s outouts should like: + + + ``` + + --- answer 1 --- + + + foo + + --- answer 2 --- + + + + + ```' + icon: 🤖 + icon_background: '#FFEAD5' + mode: advanced-chat + name: test-answer-order + use_icon_as_answer_icon: false +dependencies: +- current_identifier: null + type: marketplace + value: + marketplace_plugin_unique_identifier: langgenius/openai:0.2.6@e2665624a156f52160927bceac9e169bd7e5ae6b936ae82575e14c90af390e6e + version: null +kind: app +version: 0.4.0 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: answer + targetType: answer + id: 1759052466526-source-1759052469368-target + source: '1759052466526' + sourceHandle: source + target: '1759052469368' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: llm + id: 1759052439553-source-1759052580454-target + source: '1759052439553' + sourceHandle: source + target: '1759052580454' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: false + isInLoop: false + sourceType: llm + targetType: answer + id: 1759052580454-source-1759052466526-target + source: '1759052580454' + sourceHandle: source + target: '1759052466526' + targetHandle: target + type: custom + zIndex: 0 + nodes: + - data: + selected: false + title: Start + type: start + variables: [] + height: 52 + id: '1759052439553' + position: + x: 30 + y: 242 + positionAbsolute: + x: 30 + y: 242 + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + answer: '--- answer 1 --- + + + foo + + ' + selected: false + title: Answer + type: answer + variables: [] + height: 100 + id: '1759052466526' + position: + x: 632 + y: 242 + positionAbsolute: + x: 632 + y: 242 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + answer: '--- answer 2 --- + + + {{#1759052580454.text#}} + + ' + selected: false + title: Answer 2 + type: answer + variables: [] + height: 103 + id: '1759052469368' + position: + x: 934 + y: 242 + positionAbsolute: + x: 934 + y: 242 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + context: + enabled: false + variable_selector: [] + model: + completion_params: + temperature: 0.7 + mode: chat + name: gpt-4o + provider: langgenius/openai/openai + prompt_template: + - id: 5c1d873b-06b2-4dce-939e-672882bbd7c0 + role: system + text: '' + - role: user + text: '{{#sys.query#}}' + selected: false + title: LLM + type: llm + vision: + enabled: false + height: 88 + id: '1759052580454' + position: + x: 332 + y: 242 + positionAbsolute: + x: 332 + y: 242 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 242 + viewport: + x: 126.2797574512839 + y: 289.55932160537446 + zoom: 1.0743222672006216 + rag_pipeline_variables: [] diff --git a/api/tests/fixtures/workflow/update-conversation-variable-in-iteration.yml b/api/tests/fixtures/workflow/update-conversation-variable-in-iteration.yml new file mode 100644 index 0000000000..ffc6eb9120 --- /dev/null +++ b/api/tests/fixtures/workflow/update-conversation-variable-in-iteration.yml @@ -0,0 +1,316 @@ +app: + description: 'This chatflow receives a sys.query, writes it into the `answer` variable, + and then outputs the `answer` variable. + + + `answer` is a conversation variable with a blank default value; it will be updated + in an iteration node. + + + if this chatflow works correctly, it will output the `sys.query` as the same.' + icon: 🤖 + icon_background: '#FFEAD5' + mode: advanced-chat + name: update-conversation-variable-in-iteration + use_icon_as_answer_icon: false +dependencies: [] +kind: app +version: 0.4.0 +workflow: + conversation_variables: + - description: '' + id: c30af82d-b2ec-417d-a861-4dd78584faa4 + name: answer + selector: + - conversation + - answer + value: '' + value_type: string + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: code + id: 1759032354471-source-1759032363865-target + source: '1759032354471' + sourceHandle: source + target: '1759032363865' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: false + isInLoop: false + sourceType: code + targetType: iteration + id: 1759032363865-source-1759032379989-target + source: '1759032363865' + sourceHandle: source + target: '1759032379989' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: true + isInLoop: false + iteration_id: '1759032379989' + sourceType: iteration-start + targetType: assigner + id: 1759032379989start-source-1759032394460-target + source: 1759032379989start + sourceHandle: source + target: '1759032394460' + targetHandle: target + type: custom + zIndex: 1002 + - data: + isInIteration: false + isInLoop: false + sourceType: iteration + targetType: answer + id: 1759032379989-source-1759032410331-target + source: '1759032379989' + sourceHandle: source + target: '1759032410331' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: true + isInLoop: false + iteration_id: '1759032379989' + sourceType: assigner + targetType: code + id: 1759032394460-source-1759032476318-target + source: '1759032394460' + sourceHandle: source + target: '1759032476318' + targetHandle: target + type: custom + zIndex: 1002 + nodes: + - data: + selected: false + title: Start + type: start + variables: [] + height: 52 + id: '1759032354471' + position: + x: 30 + y: 302 + positionAbsolute: + x: 30 + y: 302 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + code: "\ndef main():\n return {\n \"result\": [1],\n }\n" + code_language: python3 + outputs: + result: + children: null + type: array[number] + selected: false + title: Code + type: code + variables: [] + height: 52 + id: '1759032363865' + position: + x: 332 + y: 302 + positionAbsolute: + x: 332 + y: 302 + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + error_handle_mode: terminated + height: 204 + is_parallel: false + iterator_input_type: array[number] + iterator_selector: + - '1759032363865' + - result + output_selector: + - '1759032476318' + - result + output_type: array[string] + parallel_nums: 10 + selected: false + start_node_id: 1759032379989start + title: Iteration + type: iteration + width: 808 + height: 204 + id: '1759032379989' + position: + x: 634 + y: 302 + positionAbsolute: + x: 634 + y: 302 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 808 + zIndex: 1 + - data: + desc: '' + isInIteration: true + selected: false + title: '' + type: iteration-start + draggable: false + height: 48 + id: 1759032379989start + parentId: '1759032379989' + position: + x: 60 + y: 78 + positionAbsolute: + x: 694 + y: 380 + selectable: false + sourcePosition: right + targetPosition: left + type: custom-iteration-start + width: 44 + zIndex: 1002 + - data: + isInIteration: true + isInLoop: false + items: + - input_type: variable + operation: over-write + value: + - sys + - query + variable_selector: + - conversation + - answer + write_mode: over-write + iteration_id: '1759032379989' + selected: false + title: Variable Assigner + type: assigner + version: '2' + height: 84 + id: '1759032394460' + parentId: '1759032379989' + position: + x: 204 + y: 60 + positionAbsolute: + x: 838 + y: 362 + sourcePosition: right + targetPosition: left + type: custom + width: 242 + zIndex: 1002 + - data: + answer: '{{#conversation.answer#}}' + selected: false + title: Answer + type: answer + variables: [] + height: 104 + id: '1759032410331' + position: + x: 1502 + y: 302 + positionAbsolute: + x: 1502 + y: 302 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + code: "\ndef main():\n return {\n \"result\": '',\n }\n" + code_language: python3 + isInIteration: true + isInLoop: false + iteration_id: '1759032379989' + outputs: + result: + children: null + type: string + selected: false + title: Code 2 + type: code + variables: [] + height: 52 + id: '1759032476318' + parentId: '1759032379989' + position: + x: 506 + y: 76 + positionAbsolute: + x: 1140 + y: 378 + sourcePosition: right + targetPosition: left + type: custom + width: 242 + zIndex: 1002 + viewport: + x: 120.39999999999998 + y: 85.20000000000005 + zoom: 0.7 + rag_pipeline_variables: [] diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 92df93fb13..23a0ecf714 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -167,7 +167,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 -WORKFLOW_PARALLEL_DEPTH_LIMIT=3 MAX_VARIABLE_SIZE=204800 # App configuration diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index c8d353ad0a..498ac56d5d 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -11,8 +11,8 @@ from controllers.console.app import completion as completion_api from controllers.console.app import message as message_api from controllers.console.app import wraps from libs.datetime_utils import naive_utc_now -from models import Account, App, Tenant -from models.account import TenantAccountRole +from models import App, Tenant +from models.account import Account, TenantAccountJoin, TenantAccountRole from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -25,29 +25,42 @@ class TestChatMessageApiPermissions: """Create a mock App model for testing.""" app = App() app.id = str(uuid.uuid4()) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT app.tenant_id = str(uuid.uuid4()) app.status = "normal" return app @pytest.fixture - def mock_account(self): + def mock_account(self, monkeypatch: pytest.MonkeyPatch): """Create a mock Account for testing.""" - account = Account() - account.id = str(uuid.uuid4()) - account.name = "Test User" - account.email = "test@example.com" + account = Account( + name="Test User", + email="test@example.com", + ) account.last_active_at = naive_utc_now() account.created_at = naive_utc_now() account.updated_at = naive_utc_now() + account.id = str(uuid.uuid4()) # Create mock tenant - tenant = Tenant() + tenant = Tenant(name="Test Tenant") tenant.id = str(uuid.uuid4()) - tenant.name = "Test Tenant" - account._current_tenant = tenant + mock_session_instance = mock.Mock() + + mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) + monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) + + mock_scalars_result = mock.Mock() + mock_scalars_result.one.return_value = tenant + monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_session_instance + monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) + + account.current_tenant = tenant return account @pytest.mark.parametrize( diff --git a/api/tests/integration_tests/controllers/console/app/test_description_validation.py b/api/tests/integration_tests/controllers/console/app/test_description_validation.py index 2d0ceac760..8160807e48 100644 --- a/api/tests/integration_tests/controllers/console/app/test_description_validation.py +++ b/api/tests/integration_tests/controllers/console/app/test_description_validation.py @@ -18,124 +18,87 @@ class TestAppDescriptionValidationUnit: """Unit tests for description validation function""" def test_validate_description_length_function(self): - """Test the _validate_description_length function directly""" - from controllers.console.app.app import _validate_description_length + """Test the validate_description_length function directly""" + from libs.validators import validate_description_length # Test valid descriptions - assert _validate_description_length("") == "" - assert _validate_description_length("x" * 400) == "x" * 400 - assert _validate_description_length(None) is None + assert validate_description_length("") == "" + assert validate_description_length("x" * 400) == "x" * 400 + assert validate_description_length(None) is None # Test invalid descriptions with pytest.raises(ValueError) as exc_info: - _validate_description_length("x" * 401) + validate_description_length("x" * 401) assert "Description cannot exceed 400 characters." in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - _validate_description_length("x" * 500) + validate_description_length("x" * 500) assert "Description cannot exceed 400 characters." in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - _validate_description_length("x" * 1000) + validate_description_length("x" * 1000) assert "Description cannot exceed 400 characters." in str(exc_info.value) - def test_validation_consistency_with_dataset(self): - """Test that App and Dataset validation functions are consistent""" - from controllers.console.app.app import _validate_description_length as app_validate - from controllers.console.datasets.datasets import _validate_description_length as dataset_validate - from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate - - # Test same valid inputs - valid_desc = "x" * 400 - assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc) - assert app_validate("") == dataset_validate("") == service_dataset_validate("") - assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None) - - # Test same invalid inputs produce same error - invalid_desc = "x" * 401 - - app_error = None - dataset_error = None - service_dataset_error = None - - try: - app_validate(invalid_desc) - except ValueError as e: - app_error = str(e) - - try: - dataset_validate(invalid_desc) - except ValueError as e: - dataset_error = str(e) - - try: - service_dataset_validate(invalid_desc) - except ValueError as e: - service_dataset_error = str(e) - - assert app_error == dataset_error == service_dataset_error - assert app_error == "Description cannot exceed 400 characters." - def test_boundary_values(self): """Test boundary values for description validation""" - from controllers.console.app.app import _validate_description_length + from libs.validators import validate_description_length # Test exact boundary exactly_400 = "x" * 400 - assert _validate_description_length(exactly_400) == exactly_400 + assert validate_description_length(exactly_400) == exactly_400 # Test just over boundary just_over_400 = "x" * 401 with pytest.raises(ValueError): - _validate_description_length(just_over_400) + validate_description_length(just_over_400) # Test just under boundary just_under_400 = "x" * 399 - assert _validate_description_length(just_under_400) == just_under_400 + assert validate_description_length(just_under_400) == just_under_400 def test_edge_cases(self): """Test edge cases for description validation""" - from controllers.console.app.app import _validate_description_length + from libs.validators import validate_description_length # Test None input - assert _validate_description_length(None) is None + assert validate_description_length(None) is None # Test empty string - assert _validate_description_length("") == "" + assert validate_description_length("") == "" # Test single character - assert _validate_description_length("a") == "a" + assert validate_description_length("a") == "a" # Test unicode characters unicode_desc = "测试" * 200 # 400 characters in Chinese - assert _validate_description_length(unicode_desc) == unicode_desc + assert validate_description_length(unicode_desc) == unicode_desc # Test unicode over limit unicode_over = "测试" * 201 # 402 characters with pytest.raises(ValueError): - _validate_description_length(unicode_over) + validate_description_length(unicode_over) def test_whitespace_handling(self): """Test how validation handles whitespace""" - from controllers.console.app.app import _validate_description_length + from libs.validators import validate_description_length # Test description with spaces spaces_400 = " " * 400 - assert _validate_description_length(spaces_400) == spaces_400 + assert validate_description_length(spaces_400) == spaces_400 # Test description with spaces over limit spaces_401 = " " * 401 with pytest.raises(ValueError): - _validate_description_length(spaces_401) + validate_description_length(spaces_401) # Test mixed content mixed_400 = "a" * 200 + " " * 200 - assert _validate_description_length(mixed_400) == mixed_400 + assert validate_description_length(mixed_400) == mixed_400 # Test mixed over limit mixed_401 = "a" * 200 + " " * 201 with pytest.raises(ValueError): - _validate_description_length(mixed_401) + validate_description_length(mixed_401) if __name__ == "__main__": diff --git a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py index ca4d452963..04945e57a0 100644 --- a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py @@ -9,8 +9,8 @@ from flask.testing import FlaskClient from controllers.console.app import model_config as model_config_api from controllers.console.app import wraps from libs.datetime_utils import naive_utc_now -from models import Account, App, Tenant -from models.account import TenantAccountRole +from models import App, Tenant +from models.account import Account, TenantAccountJoin, TenantAccountRole from models.model import AppMode from services.app_model_config_service import AppModelConfigService @@ -23,30 +23,40 @@ class TestModelConfigResourcePermissions: """Create a mock App model for testing.""" app = App() app.id = str(uuid.uuid4()) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT app.tenant_id = str(uuid.uuid4()) app.status = "normal" app.app_model_config_id = str(uuid.uuid4()) return app @pytest.fixture - def mock_account(self): + def mock_account(self, monkeypatch: pytest.MonkeyPatch): """Create a mock Account for testing.""" - account = Account() + account = Account(name="Test User", email="test@example.com") account.id = str(uuid.uuid4()) - account.name = "Test User" - account.email = "test@example.com" account.last_active_at = naive_utc_now() account.created_at = naive_utc_now() account.updated_at = naive_utc_now() # Create mock tenant - tenant = Tenant() + tenant = Tenant(name="Test Tenant") tenant.id = str(uuid.uuid4()) - tenant.name = "Test Tenant" - account._current_tenant = tenant + mock_session_instance = mock.Mock() + + mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) + monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) + + mock_scalars_result = mock.Mock() + mock_scalars_result.one.return_value = tenant + monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_session_instance + monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) + + account.current_tenant = tenant return account @pytest.mark.parametrize( diff --git a/api/tests/integration_tests/plugin/__mock/http.py b/api/tests/integration_tests/plugin/__mock/http.py index 8f8988899b..d5cf47e2c2 100644 --- a/api/tests/integration_tests/plugin/__mock/http.py +++ b/api/tests/integration_tests/plugin/__mock/http.py @@ -1,8 +1,8 @@ import os from typing import Literal +import httpx import pytest -import requests from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse from core.tools.entities.common_entities import I18nObject @@ -27,13 +27,11 @@ class MockedHttp: @classmethod def requests_request( cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs - ) -> requests.Response: + ) -> httpx.Response: """ - Mocked requests.request + Mocked httpx.request """ - request = requests.PreparedRequest() - request.method = method - request.url = url + request = httpx.Request(method, url) if url.endswith("/tools"): content = PluginDaemonBasicResponse[list[ToolProviderEntity]]( code=0, message="success", data=cls.list_tools() @@ -41,8 +39,7 @@ class MockedHttp: else: raise ValueError("") - response = requests.Response() - response.status_code = 200 + response = httpx.Response(status_code=200) response.request = request response._content = content.encode("utf-8") return response @@ -54,7 +51,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch): if MOCK_SWITCH: - monkeypatch.setattr(requests, "request", MockedHttp.requests_request) + monkeypatch.setattr(httpx, "request", MockedHttp.requests_request) def unpatch(): monkeypatch.undo() diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index aeee882750..f3a5ba0d11 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -542,7 +542,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): index=1, node_execution_id=str(uuid.uuid4()), node_id=self._node_id, - node_type=NodeType.LLM.value, + node_type=NodeType.LLM, title="Test Node", inputs='{"input": "test input"}', process_data='{"test_var": "process_value", "other_var": "other_process"}', diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index 7c1a200c8f..e637530265 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -36,7 +36,7 @@ def test_api_tool(setup_http_mock): entity=ToolEntity( identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")), ), - api_bundle=ApiToolBundle(**tool_bundle), + api_bundle=ApiToolBundle.model_validate(tool_bundle), runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}), provider_id="test_tool", ) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 6d2aff5197..8a43d03a43 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -1,5 +1,6 @@ import os from collections import UserDict +from typing import Any from unittest.mock import MagicMock import pytest @@ -9,7 +10,6 @@ from pymochow.model.database import Database # type: ignore from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore from pymochow.model.table import Table # type: ignore -from requests.adapters import HTTPAdapter class AttrDict(UserDict): @@ -21,7 +21,7 @@ class MockBaiduVectorDBClass: def mock_vector_db_client( self, config=None, - adapter: HTTPAdapter | None = None, + adapter: Any | None = None, ): self.conn = MagicMock() self._config = MagicMock() diff --git a/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py b/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py index 9706c52455..9e24672317 100644 --- a/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py @@ -44,25 +44,25 @@ class MockClient: "hits": [ { "_source": { - Field.CONTENT_KEY.value: "abcdef", - Field.VECTOR.value: [1, 2], - Field.METADATA_KEY.value: {}, + Field.CONTENT_KEY: "abcdef", + Field.VECTOR: [1, 2], + Field.METADATA_KEY: {}, }, "_score": 1.0, }, { "_source": { - Field.CONTENT_KEY.value: "123456", - Field.VECTOR.value: [2, 2], - Field.METADATA_KEY.value: {}, + Field.CONTENT_KEY: "123456", + Field.VECTOR: [2, 2], + Field.METADATA_KEY: {}, }, "_score": 0.9, }, { "_source": { - Field.CONTENT_KEY.value: "a1b2c3", - Field.VECTOR.value: [3, 2], - Field.METADATA_KEY.value: {}, + Field.CONTENT_KEY: "a1b2c3", + Field.VECTOR: [3, 2], + Field.METADATA_KEY: {}, }, "_score": 0.8, }, diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index e0b908cece..5130fcfe17 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -1,9 +1,8 @@ import os -from typing import Union +from typing import Any, Union import pytest from _pytest.monkeypatch import MonkeyPatch -from requests.adapters import HTTPAdapter from tcvectordb import RPCVectorDBClient # type: ignore from tcvectordb.model import enum from tcvectordb.model.collection import FilterIndexConfig @@ -23,7 +22,7 @@ class MockTcvectordbClass: key="", read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, timeout=10, - adapter: HTTPAdapter | None = None, + adapter: Any | None = None, pool_size: int = 2, proxies: dict | None = None, password: str | None = None, diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index 3ad72e5550..f351df8d5b 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -40,13 +40,13 @@ class MockVikingDBClass: collection_name=collection_name, description="Collection For Dify", viking_db_service=self._viking_db_service, - primary_key=vdb_Field.PRIMARY_KEY.value, + primary_key=vdb_Field.PRIMARY_KEY, fields=[ - Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True), - Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String), - Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String), - Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text), - Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768), + Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True), + Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String), + Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String), + Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text), + Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=768), ], indexes=[ Index( @@ -71,7 +71,7 @@ class MockVikingDBClass: return Collection( collection_name=collection_name, description=description, - primary_key=vdb_Field.PRIMARY_KEY.value, + primary_key=vdb_Field.PRIMARY_KEY, viking_db_service=self._viking_db_service, fields=fields, ) @@ -126,11 +126,11 @@ class MockVikingDBClass: def fetch_data(self, id: Union[str, list[str], int, list[int]]): return Data( fields={ - vdb_Field.GROUP_KEY.value: "test_group", - vdb_Field.METADATA_KEY.value: "{}", - vdb_Field.CONTENT_KEY.value: "content", - vdb_Field.PRIMARY_KEY.value: id, - vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], + vdb_Field.GROUP_KEY: "test_group", + vdb_Field.METADATA_KEY: "{}", + vdb_Field.CONTENT_KEY: "content", + vdb_Field.PRIMARY_KEY: id, + vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], }, id=id, ) @@ -151,16 +151,16 @@ class MockVikingDBClass: return [ Data( fields={ - vdb_Field.GROUP_KEY.value: "test_group", - vdb_Field.METADATA_KEY.value: '\ + vdb_Field.GROUP_KEY: "test_group", + vdb_Field.METADATA_KEY: '\ {"source": "/var/folders/ml/xxx/xxx.txt", \ "document_id": "test_document_id", \ "dataset_id": "test_dataset_id", \ "doc_id": "test_id", \ "doc_hash": "test_hash"}', - vdb_Field.CONTENT_KEY.value: "content", - vdb_Field.PRIMARY_KEY.value: "test_id", - vdb_Field.VECTOR.value: vector, + vdb_Field.CONTENT_KEY: "content", + vdb_Field.PRIMARY_KEY: "test_id", + vdb_Field.VECTOR: vector, }, id="test_id", score=0.10, @@ -173,16 +173,16 @@ class MockVikingDBClass: return [ Data( fields={ - vdb_Field.GROUP_KEY.value: "test_group", - vdb_Field.METADATA_KEY.value: '\ + vdb_Field.GROUP_KEY: "test_group", + vdb_Field.METADATA_KEY: '\ {"source": "/var/folders/ml/xxx/xxx.txt", \ "document_id": "test_document_id", \ "dataset_id": "test_dataset_id", \ "doc_id": "test_id", \ "doc_hash": "test_hash"}', - vdb_Field.CONTENT_KEY.value: "content", - vdb_Field.PRIMARY_KEY.value: "test_id", - vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], + vdb_Field.CONTENT_KEY: "content", + vdb_Field.PRIMARY_KEY: "test_id", + vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], }, id="test_id", score=0.10, diff --git a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py index ef54eaa174..60e3f30f26 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py +++ b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py @@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment import os import time -import requests +import httpx from clickzetta import connect @@ -66,7 +66,7 @@ def test_dify_api(): max_retries = 30 for i in range(max_retries): try: - response = requests.get(f"{base_url}/console/api/health") + response = httpx.get(f"{base_url}/console/api/health") if response.status_code == 200: print("✓ Dify API is ready") break diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index 2d44dd2924..192c995ce5 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -129,8 +129,8 @@ class TestOpenSearchVector: "hits": [ { "_source": { - Field.CONTENT_KEY.value: get_example_text(), - Field.METADATA_KEY.value: {"document_id": self.example_doc_id}, + Field.CONTENT_KEY: get_example_text(), + Field.METADATA_KEY: {"document_id": self.example_doc_id}, }, "_score": 1.0, } diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e2f3a74bf9..b62d8aa544 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,9 +1,9 @@ import time import uuid -from os import getenv import pytest +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool from core.workflow.enums import WorkflowNodeExecutionStatus @@ -15,7 +15,7 @@ from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) +CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH def init_code_node(code_config: dict): diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 145e31bca0..180ee1c963 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -18,6 +18,7 @@ from flask.testing import FlaskClient from sqlalchemy import Engine, text from sqlalchemy.orm import Session from testcontainers.core.container import DockerContainer +from testcontainers.core.network import Network from testcontainers.core.waiting_utils import wait_for_logs from testcontainers.postgres import PostgresContainer from testcontainers.redis import RedisContainer @@ -41,6 +42,7 @@ class DifyTestContainers: def __init__(self): """Initialize container management with default configurations.""" + self.network: Network | None = None self.postgres: PostgresContainer | None = None self.redis: RedisContainer | None = None self.dify_sandbox: DockerContainer | None = None @@ -62,12 +64,18 @@ class DifyTestContainers: logger.info("Starting test containers for Dify integration tests...") + # Create Docker network for container communication + logger.info("Creating Docker network for container communication...") + self.network = Network() + self.network.create() + logger.info("Docker network created successfully with name: %s", self.network.name) + # Start PostgreSQL container for main application database # PostgreSQL is used for storing user data, workflows, and application state logger.info("Initializing PostgreSQL container...") self.postgres = PostgresContainer( image="postgres:14-alpine", - ) + ).with_network(self.network) self.postgres.start() db_host = self.postgres.get_container_host_ip() db_port = self.postgres.get_exposed_port(5432) @@ -137,7 +145,7 @@ class DifyTestContainers: # Start Redis container for caching and session management # Redis is used for storing session data, cache entries, and temporary data logger.info("Initializing Redis container...") - self.redis = RedisContainer(image="redis:6-alpine", port=6379) + self.redis = RedisContainer(image="redis:6-alpine", port=6379).with_network(self.network) self.redis.start() redis_host = self.redis.get_container_host_ip() redis_port = self.redis.get_exposed_port(6379) @@ -153,7 +161,7 @@ class DifyTestContainers: # Start Dify Sandbox container for code execution environment # Dify Sandbox provides a secure environment for executing user code logger.info("Initializing Dify Sandbox container...") - self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest") + self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) self.dify_sandbox.env = { "API_KEY": "test_api_key", @@ -173,22 +181,28 @@ class DifyTestContainers: # Start Dify Plugin Daemon container for plugin management # Dify Plugin Daemon provides plugin lifecycle management and execution logger.info("Initializing Dify Plugin Daemon container...") - self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.2.0-local") + self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local").with_network( + self.network + ) self.dify_plugin_daemon.with_exposed_ports(5002) + # Get container internal network addresses + postgres_container_name = self.postgres.get_wrapped_container().name + redis_container_name = self.redis.get_wrapped_container().name + self.dify_plugin_daemon.env = { - "DB_HOST": db_host, - "DB_PORT": str(db_port), + "DB_HOST": postgres_container_name, # Use container name for internal network communication + "DB_PORT": "5432", # Use internal port "DB_USERNAME": self.postgres.username, "DB_PASSWORD": self.postgres.password, "DB_DATABASE": "dify_plugin", - "REDIS_HOST": redis_host, - "REDIS_PORT": str(redis_port), + "REDIS_HOST": redis_container_name, # Use container name for internal network communication + "REDIS_PORT": "6379", # Use internal port "REDIS_PASSWORD": "", "SERVER_PORT": "5002", "SERVER_KEY": "test_plugin_daemon_key", "MAX_PLUGIN_PACKAGE_SIZE": "52428800", "PPROF_ENABLED": "false", - "DIFY_INNER_API_URL": f"http://{db_host}:5001", + "DIFY_INNER_API_URL": f"http://{postgres_container_name}:5001", "DIFY_INNER_API_KEY": "test_inner_api_key", "PLUGIN_REMOTE_INSTALLING_HOST": "0.0.0.0", "PLUGIN_REMOTE_INSTALLING_PORT": "5003", @@ -253,6 +267,15 @@ class DifyTestContainers: # Log error but don't fail the test cleanup logger.warning("Failed to stop container %s: %s", container, e) + # Stop and remove the network + if self.network: + try: + logger.info("Removing Docker network...") + self.network.remove() + logger.info("Successfully removed Docker network") + except Exception as e: + logger.warning("Failed to remove Docker network: %s", e) + self._containers_started = False logger.info("All test containers stopped and cleaned up successfully") diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index c98406d845..6eff73a8f3 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -16,6 +16,7 @@ from services.errors.account import ( AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, + TenantNotFoundError, ) from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError @@ -63,7 +64,7 @@ class TestAccountService: password=password, ) assert account.email == email - assert account.status == AccountStatus.ACTIVE.value + assert account.status == AccountStatus.ACTIVE # Login with correct password logged_in = AccountService.authenticate(email, password) @@ -184,7 +185,7 @@ class TestAccountService: ) # Ban the account - account.status = AccountStatus.BANNED.value + account.status = AccountStatus.BANNED from extensions.ext_database import db db.session.commit() @@ -268,14 +269,14 @@ class TestAccountService: interface_language="en-US", password=password, ) - account.status = AccountStatus.PENDING.value + account.status = AccountStatus.PENDING from extensions.ext_database import db db.session.commit() # Authenticate should activate the account authenticated_account = AccountService.authenticate(email, password) - assert authenticated_account.status == AccountStatus.ACTIVE.value + assert authenticated_account.status == AccountStatus.ACTIVE assert authenticated_account.initialized_at is not None def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -538,7 +539,7 @@ class TestAccountService: from extensions.ext_database import db db.session.refresh(account) - assert account.status == AccountStatus.CLOSED.value + assert account.status == AccountStatus.CLOSED def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -678,7 +679,7 @@ class TestAccountService: interface_language="en-US", password=password, ) - account.status = AccountStatus.PENDING.value + account.status = AccountStatus.PENDING from extensions.ext_database import db db.session.commit() @@ -687,7 +688,7 @@ class TestAccountService: token_pair = AccountService.login(account) db.session.refresh(account) - assert account.status == AccountStatus.ACTIVE.value + assert account.status == AccountStatus.ACTIVE def test_logout(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -859,7 +860,7 @@ class TestAccountService: ) # Ban the account - account.status = AccountStatus.BANNED.value + account.status = AccountStatus.BANNED from extensions.ext_database import db db.session.commit() @@ -989,7 +990,7 @@ class TestAccountService: ) # Ban the account - account.status = AccountStatus.BANNED.value + account.status = AccountStatus.BANNED from extensions.ext_database import db db.session.commit() @@ -1414,7 +1415,7 @@ class TestTenantService: ) # Try to get current tenant (should fail) - with pytest.raises(AttributeError): + with pytest.raises((AttributeError, TenantNotFoundError)): TenantService.get_current_tenant_by_account(account) def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index ca0f309fd4..9386687a04 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from openai._exceptions import RateLimitError from core.app.entities.app_invoke_entities import InvokeFrom from models.model import EndUser @@ -484,36 +483,6 @@ class TestAppGenerateService: # Verify error message assert "Rate limit exceeded" in str(exc_info.value) - def test_generate_with_rate_limit_error_from_openai( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test generation when OpenAI rate limit error occurs. - """ - fake = Faker() - app, account = self._create_test_app_and_account( - db_session_with_containers, mock_external_service_dependencies, mode="completion" - ) - - # Setup completion generator to raise RateLimitError - mock_response = MagicMock() - mock_response.request = MagicMock() - mock_external_service_dependencies["completion_generator"].return_value.generate.side_effect = RateLimitError( - "Rate limit exceeded", response=mock_response, body=None - ) - - # Setup test arguments - args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} - - # Execute the method under test and expect rate limit error - with pytest.raises(InvokeRateLimitError) as exc_info: - AppGenerateService.generate( - app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True - ) - - # Verify error message - assert "Rate limit exceeded" in str(exc_info.value) - def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies): """ Test generation with invalid app mode. diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 5598c5bc0c..e6bfc157c7 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -86,7 +86,7 @@ class TestFileService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -187,7 +187,7 @@ class TestFileService: assert upload_file.extension == "pdf" assert upload_file.mime_type == mimetype assert upload_file.created_by == account.id - assert upload_file.created_by_role == CreatorUserRole.ACCOUNT.value + assert upload_file.created_by_role == CreatorUserRole.ACCOUNT assert upload_file.used is False assert upload_file.hash == hashlib.sha3_256(content).hexdigest() @@ -216,7 +216,7 @@ class TestFileService: assert upload_file is not None assert upload_file.created_by == end_user.id - assert upload_file.created_by_role == CreatorUserRole.END_USER.value + assert upload_file.created_by_role == CreatorUserRole.END_USER def test_upload_file_with_datasets_source( self, db_session_with_containers, engine, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index d0f7e945f1..253791cc2d 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -72,7 +72,7 @@ class TestMetadataService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 66527dd506..8a72331425 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -103,7 +103,7 @@ class TestModelLoadBalancingService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 2196da8b3e..fb319a4963 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -67,7 +67,7 @@ class TestModelProviderService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 04cff397b2..3d1226019b 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -66,7 +66,7 @@ class TestTagService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index c9ace46c55..5db7901cbc 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -144,7 +144,7 @@ class TestWebConversationService: system_instruction=fake.text(max_nb_chars=300), system_instruction_tokens=50, status="normal", - invoke_from=InvokeFrom.WEB_APP.value, + invoke_from=InvokeFrom.WEB_APP, from_source="console" if isinstance(user, Account) else "api", from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 316cfe1674..059767458a 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -87,7 +87,7 @@ class TestWebAppAuthService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -150,7 +150,7 @@ class TestWebAppAuthService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -232,7 +232,7 @@ class TestWebAppAuthService: assert result.id == account.id assert result.email == account.email assert result.name == account.name - assert result.status == AccountStatus.ACTIVE.value + assert result.status == AccountStatus.ACTIVE # Verify database state from extensions.ext_database import db @@ -280,7 +280,7 @@ class TestWebAppAuthService: email=fake.email(), name=fake.name(), interface_language="en-US", - status=AccountStatus.BANNED.value, + status=AccountStatus.BANNED, ) # Hash password @@ -411,7 +411,7 @@ class TestWebAppAuthService: assert result.id == account.id assert result.email == account.email assert result.name == account.name - assert result.status == AccountStatus.ACTIVE.value + assert result.status == AccountStatus.ACTIVE # Verify database state from extensions.ext_database import db @@ -455,7 +455,7 @@ class TestWebAppAuthService: email=unique_email, name=fake.name(), interface_language="en-US", - status=AccountStatus.BANNED.value, + status=AccountStatus.BANNED, ) from extensions.ext_database import db diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 2e18184aea..62c9bead86 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -199,7 +199,7 @@ class TestWorkflowAppService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), finished_at=datetime.now(UTC), @@ -215,7 +215,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), ) @@ -356,7 +356,7 @@ class TestWorkflowAppService: elapsed_time=1.0 + i, total_tokens=100 + i * 10, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, @@ -371,7 +371,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -464,7 +464,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=timestamp, finished_at=timestamp + timedelta(minutes=1), @@ -479,7 +479,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=timestamp, ) @@ -571,7 +571,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), @@ -586,7 +586,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -701,7 +701,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), @@ -716,7 +716,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -743,7 +743,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.END_USER.value, + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, created_at=datetime.now(UTC) + timedelta(minutes=i + 10), finished_at=datetime.now(UTC) + timedelta(minutes=i + 11), @@ -758,7 +758,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="web-app", - created_by_role=CreatorUserRole.END_USER.value, + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, created_at=datetime.now(UTC) + timedelta(minutes=i + 10), ) @@ -780,14 +780,14 @@ class TestWorkflowAppService: limit=20, ) assert result_session_filter["total"] == 2 - assert all(log.created_by_role == CreatorUserRole.END_USER.value for log in result_session_filter["data"]) + assert all(log.created_by_role == CreatorUserRole.END_USER for log in result_session_filter["data"]) # Test filtering by account email result_account_filter = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, created_by_account=account.email, page=1, limit=20 ) assert result_account_filter["total"] == 3 - assert all(log.created_by_role == CreatorUserRole.ACCOUNT.value for log in result_account_filter["data"]) + assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_account_filter["data"]) # Test filtering by non-existent session ID result_no_session = service.get_paginate_workflow_app_logs( @@ -853,7 +853,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), finished_at=datetime.now(UTC) + timedelta(minutes=1), @@ -869,7 +869,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), ) @@ -943,7 +943,7 @@ class TestWorkflowAppService: elapsed_time=0.0, # Edge case: 0 elapsed time total_tokens=0, # Edge case: 0 tokens total_steps=0, # Edge case: 0 steps - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), finished_at=datetime.now(UTC), @@ -959,7 +959,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), ) @@ -1098,7 +1098,7 @@ class TestWorkflowAppService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status == "succeeded" else None, @@ -1113,7 +1113,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -1198,7 +1198,7 @@ class TestWorkflowAppService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, @@ -1213,7 +1213,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -1300,7 +1300,7 @@ class TestWorkflowAppService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j), finished_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j + 1), @@ -1315,7 +1315,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j), ) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 4cb21ef6bd..23c4eeb82f 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -130,7 +130,7 @@ class TestWorkflowRunService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=created_time, finished_at=created_time, @@ -167,7 +167,7 @@ class TestWorkflowRunService: inputs={}, status="normal", mode="chat", - from_source=CreatorUserRole.ACCOUNT.value, + from_source=CreatorUserRole.ACCOUNT, from_account_id=account.id, ) db.session.add(conversation) @@ -188,7 +188,7 @@ class TestWorkflowRunService: message.answer_price_unit = 0.001 message.currency = "USD" message.status = "normal" - message.from_source = CreatorUserRole.ACCOUNT.value + message.from_source = CreatorUserRole.ACCOUNT message.from_account_id = account.id message.workflow_run_id = workflow_run.id message.inputs = {"input": "test input"} @@ -458,7 +458,7 @@ class TestWorkflowRunService: status="succeeded", elapsed_time=0.5, execution_metadata=json.dumps({"tokens": 50}), - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), ) @@ -689,7 +689,7 @@ class TestWorkflowRunService: status="succeeded", elapsed_time=0.5, execution_metadata=json.dumps({"tokens": 50}), - created_by_role=CreatorUserRole.END_USER.value, + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, created_at=datetime.now(UTC), ) @@ -710,4 +710,4 @@ class TestWorkflowRunService: assert node_exec.app_id == app.id assert node_exec.workflow_run_id == workflow_run.id assert node_exec.created_by == end_user.id - assert node_exec.created_by_role == CreatorUserRole.END_USER.value + assert node_exec.created_by_role == CreatorUserRole.END_USER diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 60150667ed..4741eba1f5 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -44,27 +44,26 @@ class TestWorkflowService: Account: Created test account instance """ fake = fake or Faker() - account = Account() - account.id = fake.uuid4() - account.email = fake.email() - account.name = fake.name() - account.avatar_url = fake.url() - account.tenant_id = fake.uuid4() - account.status = "active" - account.type = "normal" - account.role = "owner" - account.interface_language = "en-US" # Set interface language for Site creation + account = Account( + email=fake.email(), + name=fake.name(), + avatar=fake.url(), + status="active", + interface_language="en-US", # Set interface language for Site creation + ) account.created_at = fake.date_time_this_year() + account.id = fake.uuid4() account.updated_at = account.created_at # Create a tenant for the account from models.account import Tenant - tenant = Tenant() - tenant.id = account.tenant_id - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" + tenant = Tenant( + name=f"Test Tenant {fake.company()}", + plan="basic", + status="active", + ) + tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -91,20 +90,21 @@ class TestWorkflowService: App: Created test app instance """ fake = fake or Faker() - app = App() - app.id = fake.uuid4() - app.tenant_id = fake.uuid4() - app.name = fake.company() - app.description = fake.text() - app.mode = AppMode.WORKFLOW - app.icon_type = "emoji" - app.icon = "🤖" - app.icon_background = "#FFEAD5" - app.enable_site = True - app.enable_api = True - app.created_by = fake.uuid4() + app = App( + id=fake.uuid4(), + tenant_id=fake.uuid4(), + name=fake.company(), + description=fake.text(), + mode=AppMode.WORKFLOW, + icon_type="emoji", + icon="🤖", + icon_background="#FFEAD5", + enable_site=True, + enable_api=True, + created_by=fake.uuid4(), + workflow_id=None, # Will be set when workflow is created + ) app.updated_by = app.created_by - app.workflow_id = None # Will be set when workflow is created from extensions.ext_database import db @@ -126,19 +126,20 @@ class TestWorkflowService: Workflow: Created test workflow instance """ fake = fake or Faker() - workflow = Workflow() - workflow.id = fake.uuid4() - workflow.tenant_id = app.tenant_id - workflow.app_id = app.id - workflow.type = WorkflowType.WORKFLOW.value - workflow.version = Workflow.VERSION_DRAFT - workflow.graph = json.dumps({"nodes": [], "edges": []}) - workflow.features = json.dumps({"features": []}) - # unique_hash is a computed property based on graph and features - workflow.created_by = account.id - workflow.updated_by = account.id - workflow.environment_variables = [] - workflow.conversation_variables = [] + workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({"features": []}), + # unique_hash is a computed property based on graph and features + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) from extensions.ext_database import db @@ -175,7 +176,7 @@ class TestWorkflowService: node_execution.node_type = "test_node" node_execution.title = "Test Node" # Required field node_execution.status = "succeeded" - node_execution.created_by_role = CreatorUserRole.ACCOUNT.value # Required field + node_execution.created_by_role = CreatorUserRole.ACCOUNT # Required field node_execution.created_by = account.id # Required field node_execution.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 3fd439256d..814d1908bd 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -69,7 +69,7 @@ class TestWorkspaceService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -111,7 +111,7 @@ class TestWorkspaceService: assert result["name"] == tenant.name assert result["plan"] == tenant.plan assert result["status"] == tenant.status - assert result["role"] == TenantAccountRole.OWNER.value + assert result["role"] == TenantAccountRole.OWNER assert result["created_at"] == tenant.created_at assert result["trial_end_reason"] is None @@ -159,7 +159,7 @@ class TestWorkspaceService: assert result["name"] == tenant.name assert result["plan"] == tenant.plan assert result["status"] == tenant.status - assert result["role"] == TenantAccountRole.OWNER.value + assert result["role"] == TenantAccountRole.OWNER assert result["created_at"] == tenant.created_at assert result["trial_end_reason"] is None @@ -194,7 +194,7 @@ class TestWorkspaceService: from extensions.ext_database import db join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - join.role = TenantAccountRole.NORMAL.value + join.role = TenantAccountRole.NORMAL db.session.commit() # Setup mocks for feature service @@ -212,7 +212,7 @@ class TestWorkspaceService: assert result["name"] == tenant.name assert result["plan"] == tenant.plan assert result["status"] == tenant.status - assert result["role"] == TenantAccountRole.NORMAL.value + assert result["role"] == TenantAccountRole.NORMAL assert result["created_at"] == tenant.created_at assert result["trial_end_reason"] is None @@ -245,7 +245,7 @@ class TestWorkspaceService: from extensions.ext_database import db join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - join.role = TenantAccountRole.ADMIN.value + join.role = TenantAccountRole.ADMIN db.session.commit() # Setup mocks for feature service and tenant service @@ -260,7 +260,7 @@ class TestWorkspaceService: # Assert: Verify the expected outcomes assert result is not None - assert result["role"] == TenantAccountRole.ADMIN.value + assert result["role"] == TenantAccountRole.ADMIN # Verify custom config is included for admin users assert "custom_config" in result @@ -378,7 +378,7 @@ class TestWorkspaceService: from extensions.ext_database import db join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - join.role = TenantAccountRole.EDITOR.value + join.role = TenantAccountRole.EDITOR db.session.commit() # Setup mocks for feature service and tenant service @@ -394,7 +394,7 @@ class TestWorkspaceService: # Assert: Verify the expected outcomes assert result is not None - assert result["role"] == TenantAccountRole.EDITOR.value + assert result["role"] == TenantAccountRole.EDITOR # Verify custom config is not included for editor users without admin privileges assert "custom_config" not in result @@ -425,7 +425,7 @@ class TestWorkspaceService: from extensions.ext_database import db join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - join.role = TenantAccountRole.DATASET_OPERATOR.value + join.role = TenantAccountRole.DATASET_OPERATOR db.session.commit() # Setup mocks for feature service and tenant service @@ -441,7 +441,7 @@ class TestWorkspaceService: # Assert: Verify the expected outcomes assert result is not None - assert result["role"] == TenantAccountRole.DATASET_OPERATOR.value + assert result["role"] == TenantAccountRole.DATASET_OPERATOR # Verify custom config is not included for dataset operators without admin privileges assert "custom_config" not in result diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index a412bdccf8..7366b08439 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -72,7 +72,7 @@ class TestApiToolManageService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index dd22dcbfd1..f7a4c53318 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -72,7 +72,7 @@ class TestMCPToolManageService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 827f9c010e..ae0c7b7a6b 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -168,7 +168,7 @@ class TestToolTransformService: """ # Arrange: Setup test data fake = Faker() - provider_type = ToolProviderType.BUILT_IN.value + provider_type = ToolProviderType.BUILT_IN provider_name = fake.company() icon = "🔧" @@ -206,7 +206,7 @@ class TestToolTransformService: """ # Arrange: Setup test data fake = Faker() - provider_type = ToolProviderType.API.value + provider_type = ToolProviderType.API provider_name = fake.company() icon = '{"background": "#FF6B6B", "content": "🔧"}' @@ -231,7 +231,7 @@ class TestToolTransformService: """ # Arrange: Setup test data with invalid JSON fake = Faker() - provider_type = ToolProviderType.API.value + provider_type = ToolProviderType.API provider_name = fake.company() icon = '{"invalid": json}' @@ -257,7 +257,7 @@ class TestToolTransformService: """ # Arrange: Setup test data fake = Faker() - provider_type = ToolProviderType.WORKFLOW.value + provider_type = ToolProviderType.WORKFLOW provider_name = fake.company() icon = {"background": "#FF6B6B", "content": "🔧"} @@ -282,7 +282,7 @@ class TestToolTransformService: """ # Arrange: Setup test data fake = Faker() - provider_type = ToolProviderType.MCP.value + provider_type = ToolProviderType.MCP provider_name = fake.company() icon = {"background": "#FF6B6B", "content": "🔧"} @@ -329,7 +329,7 @@ class TestToolTransformService: # Arrange: Setup test data fake = Faker() tenant_id = fake.uuid4() - provider = {"type": ToolProviderType.BUILT_IN.value, "name": fake.company(), "icon": "🔧"} + provider = {"type": ToolProviderType.BUILT_IN, "name": fake.company(), "icon": "🔧"} # Act: Execute the method under test ToolTransformService.repack_provider(tenant_id, provider) diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 18ab4bb73c..88aa0b6e72 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -66,7 +66,7 @@ class TestWorkflowConverter: mock_config.model = ModelConfigEntity( provider="openai", model="gpt-4", - mode=LLMMode.CHAT.value, + mode=LLMMode.CHAT, parameters={}, stop=[], ) @@ -120,7 +120,7 @@ class TestWorkflowConverter: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -150,7 +150,7 @@ class TestWorkflowConverter: app = App( tenant_id=tenant.id, name=fake.company(), - mode=AppMode.CHAT.value, + mode=AppMode.CHAT, icon_type="emoji", icon="🤖", icon_background="#FF6B6B", @@ -218,7 +218,7 @@ class TestWorkflowConverter: # Assert: Verify the expected outcomes assert new_app is not None assert new_app.name == "Test Workflow App" - assert new_app.mode == AppMode.ADVANCED_CHAT.value + assert new_app.mode == AppMode.ADVANCED_CHAT assert new_app.icon_type == "emoji" assert new_app.icon == "🚀" assert new_app.icon_background == "#4CAF50" @@ -257,7 +257,7 @@ class TestWorkflowConverter: app = App( tenant_id=tenant.id, name=fake.company(), - mode=AppMode.CHAT.value, + mode=AppMode.CHAT, icon_type="emoji", icon="🤖", icon_background="#FF6B6B", @@ -522,7 +522,7 @@ class TestWorkflowConverter: model_config = ModelConfigEntity( provider="openai", model="gpt-4", - mode=LLMMode.CHAT.value, + mode=LLMMode.CHAT, parameters={"temperature": 0.7}, stop=[], ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 4600f2addb..96e673d855 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -63,7 +63,7 @@ class TestAddDocumentToIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 3d17a8ac9d..8628e2af7f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -84,7 +84,7 @@ class TestBatchCleanDocumentTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index fcae93c669..a9cfb6ffd4 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -112,7 +112,7 @@ class TestBatchCreateSegmentToIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index e0c2da63b9..99061d215f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -784,133 +784,6 @@ class TestCleanDatasetTask: print(f"Total cleanup time: {cleanup_duration:.3f} seconds") print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds") - def test_clean_dataset_task_concurrent_cleanup_scenarios( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test dataset cleanup with concurrent cleanup scenarios and race conditions. - - This test verifies that the task can properly: - 1. Handle multiple cleanup operations on the same dataset - 2. Prevent data corruption during concurrent access - 3. Maintain data consistency across multiple cleanup attempts - 4. Handle race conditions gracefully - 5. Ensure idempotent cleanup operations - """ - # Create test data - account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(db_session_with_containers, account, tenant) - document = self._create_test_document(db_session_with_containers, account, tenant, dataset) - segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document) - upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant) - - # Update document with file reference - import json - - document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - from extensions.ext_database import db - - db.session.commit() - - # Save IDs for verification - dataset_id = dataset.id - tenant_id = tenant.id - upload_file_id = upload_file.id - - # Mock storage to simulate slow operations - mock_storage = mock_external_service_dependencies["storage"] - original_delete = mock_storage.delete - - def slow_delete(key): - import time - - time.sleep(0.1) # Simulate slow storage operation - return original_delete(key) - - mock_storage.delete.side_effect = slow_delete - - # Execute multiple cleanup operations concurrently - import threading - - cleanup_results = [] - cleanup_errors = [] - - def run_cleanup(): - try: - clean_dataset_task( - dataset_id=dataset_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - index_struct='{"type": "paragraph"}', - collection_binding_id=str(uuid.uuid4()), - doc_form="paragraph_index", - ) - cleanup_results.append("success") - except Exception as e: - cleanup_errors.append(str(e)) - - # Start multiple cleanup threads - threads = [] - for i in range(3): - thread = threading.Thread(target=run_cleanup) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify results - # Check that all documents were deleted (only once) - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset_id).all() - assert len(remaining_documents) == 0 - - # Check that all segments were deleted (only once) - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset_id).all() - assert len(remaining_segments) == 0 - - # Check that upload file was deleted (only once) - # Note: In concurrent scenarios, the first thread deletes documents and segments, - # subsequent threads may not find the related data to clean up upload files - # This demonstrates the idempotent nature of the cleanup process - remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all() - # The upload file should be deleted by the first successful cleanup operation - # However, in concurrent scenarios, this may not always happen due to race conditions - # This test demonstrates the idempotent nature of the cleanup process - if len(remaining_files) > 0: - print(f"Warning: Upload file {upload_file_id} was not deleted in concurrent scenario") - print("This is expected behavior demonstrating the idempotent nature of cleanup") - # We don't assert here as the behavior depends on timing and race conditions - - # Verify that storage.delete was called (may be called multiple times in concurrent scenarios) - # In concurrent scenarios, storage operations may be called multiple times due to race conditions - assert mock_storage.delete.call_count > 0 - - # Verify that index processor was called (may be called multiple times in concurrent scenarios) - mock_index_processor = mock_external_service_dependencies["index_processor"] - assert mock_index_processor.clean.call_count > 0 - - # Check cleanup results - assert len(cleanup_results) == 3, "All cleanup operations should complete" - assert len(cleanup_errors) == 0, "No cleanup errors should occur" - - # Verify idempotency by running cleanup again on the same dataset - # This should not perform any additional operations since data is already cleaned - clean_dataset_task( - dataset_id=dataset_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - index_struct='{"type": "paragraph"}', - collection_binding_id=str(uuid.uuid4()), - doc_form="paragraph_index", - ) - - # Verify that no additional storage operations were performed - # Note: In concurrent scenarios, the exact count may vary due to race conditions - print(f"Final storage delete calls: {mock_storage.delete.call_count}") - print(f"Final index processor calls: {mock_index_processor.clean.call_count}") - print("Note: Multiple calls in concurrent scenarios are expected due to race conditions") - def test_clean_dataset_task_storage_exception_handling( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index de81295100..987ebf8aca 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -91,7 +91,7 @@ class TestCreateSegmentToIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 7af4f238be..94e9b76965 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant() + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active") tenant.id = fake.uuid4() - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask: Account: Created test account instance """ fake = fake or Faker() - account = Account() + account = Account( + name=fake.name(), + email=fake.email(), + avatar=fake.url(), + status="active", + interface_language="en-US", + ) account.id = fake.uuid4() - account.email = fake.email() - account.name = fake.name() - account.avatar_url = fake.url() - account.tenant_id = tenant.id - account.status = "active" - account.type = "normal" - account.role = "owner" - account.interface_language = "en-US" account.created_at = fake.date_time_this_year() account.updated_at = account.created_at diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index e1d63e993b..bc3701d098 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -69,7 +69,7 @@ class TestDisableSegmentFromIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 5fdb8c617c..0b36e0914a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask: Account: Created test account instance """ fake = fake or Faker() - account = Account() + account = Account( + email=fake.email(), + name=fake.name(), + avatar=fake.url(), + status="active", + interface_language="en-US", + ) account.id = fake.uuid4() - account.email = fake.email() - account.name = fake.name() - account.avatar_url = fake.url() + # monkey-patch attributes for test setup account.tenant_id = fake.uuid4() - account.status = "active" account.type = "normal" account.role = "owner" - account.interface_language = "en-US" account.created_at = fake.date_time_this_year() account.updated_at = account.created_at # Create a tenant for the account from models.account import Tenant - tenant = Tenant() + tenant = Tenant( + name=f"Test Tenant {fake.company()}", + plan="basic", + status="active", + ) tenant.id = account.tenant_id - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask: Dataset: Created test dataset instance """ fake = fake or Faker() - dataset = Dataset() - dataset.id = fake.uuid4() - dataset.tenant_id = account.tenant_id - dataset.name = f"Test Dataset {fake.word()}" - dataset.description = fake.text(max_nb_chars=200) - dataset.provider = "vendor" - dataset.permission = "only_me" - dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" - dataset.created_by = account.id - dataset.updated_by = account.id - dataset.embedding_model = "text-embedding-ada-002" - dataset.embedding_model_provider = "openai" - dataset.built_in_field_enabled = False + dataset = Dataset( + id=fake.uuid4(), + tenant_id=account.tenant_id, + name=f"Test Dataset {fake.word()}", + description=fake.text(max_nb_chars=200), + provider="vendor", + permission="only_me", + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + updated_by=account.id, + embedding_model="text-embedding-ada-002", + embedding_model_provider="openai", + built_in_field_enabled=False, + ) from extensions.ext_database import db @@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask: """ fake = fake or Faker() document = DatasetDocument() + document.id = fake.uuid4() document.tenant_id = dataset.tenant_id document.dataset_id = dataset.id @@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask: document.archived = False document.doc_form = "text_model" # Use text_model form for testing document.doc_language = "en" - from extensions.ext_database import db db.session.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index f75dcf06e1..a315577b78 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -72,7 +72,7 @@ class TestDocumentIndexingTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -154,7 +154,7 @@ class TestDocumentIndexingTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py new file mode 100644 index 0000000000..798fe091ab --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -0,0 +1,450 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from core.rag.index_processor.constant.index_type import IndexType +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.enable_segments_to_index_task import enable_segments_to_index_task + + +class TestEnableSegmentsToIndexTask: + """Integration tests for enable_segments_to_index_task using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.enable_segments_to_index_task.IndexProcessorFactory") as mock_index_processor_factory, + ): + # Setup mock index processor + mock_processor = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "index_processor_factory": mock_index_processor_factory, + "index_processor": mock_processor, + } + + def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test dataset and document for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (dataset, document) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create document + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=1, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="completed", + enabled=True, + doc_form=IndexType.PARAGRAPH_INDEX, + ) + db.session.add(document) + db.session.commit() + + # Refresh dataset to ensure doc_form property works correctly + db.session.refresh(dataset) + + return dataset, document + + def _create_test_segments( + self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed" + ): + """ + Helper method to create test document segments. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + document: Document instance + dataset: Dataset instance + count: Number of segments to create + enabled: Whether segments should be enabled + status: Status of the segments + + Returns: + list: List of created DocumentSegment instances + """ + fake = Faker() + segments = [] + + for i in range(count): + text = fake.text(max_nb_chars=200) + segment = DocumentSegment( + id=fake.uuid4(), + tenant_id=document.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + position=i, + content=text, + word_count=len(text.split()), + tokens=len(text.split()) * 2, + index_node_id=f"node_{i}", + index_node_hash=f"hash_{i}", + enabled=enabled, + status=status, + created_by=document.created_by, + ) + db.session.add(segment) + segments.append(segment) + + db.session.commit() + return segments + + def test_enable_segments_to_index_with_different_index_type( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segments indexing with different index types. + + This test verifies: + - Proper handling of different index types + - Index processor factory integration + - Document processing with various configurations + - Redis cache key deletion + """ + # Arrange: Create test data with different index type + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, mock_external_service_dependencies + ) + + # Update document to use different index type + document.doc_form = IndexType.QA_INDEX + db.session.commit() + + # Refresh dataset to ensure doc_form property reflects the updated document + db.session.refresh(dataset) + + # Create segments + segments = self._create_test_segments(db_session_with_containers, document, dataset) + + # Set up Redis cache keys + segment_ids = [segment.id for segment in segments] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.set(indexing_cache_key, "processing", ex=300) + + # Act: Execute the task + enable_segments_to_index_task(segment_ids, dataset.id, document.id) + + # Assert: Verify different index type handling + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + # Verify the load method was called with correct parameters + call_args = mock_external_service_dependencies["index_processor"].load.call_args + assert call_args is not None + documents = call_args[0][1] # Second argument should be documents list + assert len(documents) == 3 + + # Verify Redis cache keys were deleted + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(indexing_cache_key) == 0 + + def test_enable_segments_to_index_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent dataset. + + This test verifies: + - Proper error handling for missing datasets + - Early return without processing + - Database session cleanup + - No unnecessary index processor calls + """ + # Arrange: Use non-existent dataset ID + fake = Faker() + non_existent_dataset_id = fake.uuid4() + non_existent_document_id = fake.uuid4() + segment_ids = [fake.uuid4()] + + # Act: Execute the task with non-existent dataset + enable_segments_to_index_task(segment_ids, non_existent_dataset_id, non_existent_document_id) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + mock_external_service_dependencies["index_processor"].load.assert_not_called() + + def test_enable_segments_to_index_document_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent document. + + This test verifies: + - Proper error handling for missing documents + - Early return without processing + - Database session cleanup + - No unnecessary index processor calls + """ + # Arrange: Create dataset but use non-existent document ID + dataset, _ = self._create_test_dataset_and_document( + db_session_with_containers, mock_external_service_dependencies + ) + fake = Faker() + non_existent_document_id = fake.uuid4() + segment_ids = [fake.uuid4()] + + # Act: Execute the task with non-existent document + enable_segments_to_index_task(segment_ids, dataset.id, non_existent_document_id) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + mock_external_service_dependencies["index_processor"].load.assert_not_called() + + def test_enable_segments_to_index_invalid_document_status( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of document with invalid status. + + This test verifies: + - Early return when document is disabled, archived, or not completed + - No index processing for documents not ready for indexing + - Proper database session cleanup + - No unnecessary external service calls + """ + # Arrange: Create test data with invalid document status + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, mock_external_service_dependencies + ) + + # Test different invalid statuses + invalid_statuses = [ + ("disabled", {"enabled": False}), + ("archived", {"archived": True}), + ("not_completed", {"indexing_status": "processing"}), + ] + + for _, status_attrs in invalid_statuses: + # Reset document status + document.enabled = True + document.archived = False + document.indexing_status = "completed" + db.session.commit() + + # Set invalid status + for attr, value in status_attrs.items(): + setattr(document, attr, value) + db.session.commit() + + # Create segments + segments = self._create_test_segments(db_session_with_containers, document, dataset) + segment_ids = [segment.id for segment in segments] + + # Act: Execute the task + enable_segments_to_index_task(segment_ids, dataset.id, document.id) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + mock_external_service_dependencies["index_processor"].load.assert_not_called() + + # Clean up segments for next iteration + for segment in segments: + db.session.delete(segment) + db.session.commit() + + def test_enable_segments_to_index_segments_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling when no segments are found. + + This test verifies: + - Proper handling when segments don't exist + - Early return without processing + - Database session cleanup + - Index processor is created but load is not called + """ + # Arrange: Create test data + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, mock_external_service_dependencies + ) + + # Use non-existent segment IDs + fake = Faker() + non_existent_segment_ids = [fake.uuid4() for _ in range(3)] + + # Act: Execute the task with non-existent segments + enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id) + + # Assert: Verify index processor was created but load was not called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor"].load.assert_not_called() + + def test_enable_segments_to_index_with_parent_child_structure( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segments indexing with parent-child structure. + + This test verifies: + - Proper handling of PARENT_CHILD_INDEX type + - Child document creation from segments + - Correct document structure for parent-child indexing + - Index processor receives properly structured documents + - Redis cache key deletion + """ + # Arrange: Create test data with parent-child index type + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, mock_external_service_dependencies + ) + + # Update document to use parent-child index type + document.doc_form = IndexType.PARENT_CHILD_INDEX + db.session.commit() + + # Refresh dataset to ensure doc_form property reflects the updated document + db.session.refresh(dataset) + + # Create segments with mock child chunks + segments = self._create_test_segments(db_session_with_containers, document, dataset) + + # Set up Redis cache keys + segment_ids = [segment.id for segment in segments] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.set(indexing_cache_key, "processing", ex=300) + + # Mock the get_child_chunks method for each segment + with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks: + # Setup mock to return child chunks for each segment + mock_child_chunks = [] + for i in range(2): # Each segment has 2 child chunks + mock_child = MagicMock() + mock_child.content = f"child_content_{i}" + mock_child.index_node_id = f"child_node_{i}" + mock_child.index_node_hash = f"child_hash_{i}" + mock_child_chunks.append(mock_child) + + mock_get_child_chunks.return_value = mock_child_chunks + + # Act: Execute the task + enable_segments_to_index_task(segment_ids, dataset.id, document.id) + + # Assert: Verify parent-child index processing + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexType.PARENT_CHILD_INDEX + ) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + # Verify the load method was called with correct parameters + call_args = mock_external_service_dependencies["index_processor"].load.call_args + assert call_args is not None + documents = call_args[0][1] # Second argument should be documents list + assert len(documents) == 3 # 3 segments + + # Verify each document has children + for doc in documents: + assert hasattr(doc, "children") + assert len(doc.children) == 2 # Each document has 2 children + + # Verify Redis cache keys were deleted + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(indexing_cache_key) == 0 + + def test_enable_segments_to_index_general_exception_handling( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test general exception handling during indexing process. + + This test verifies: + - Exceptions are properly caught and handled + - Segment status is set to error + - Segments are disabled + - Error information is recorded + - Redis cache is still cleared + - Database session is properly closed + """ + # Arrange: Create test data + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, mock_external_service_dependencies + ) + segments = self._create_test_segments(db_session_with_containers, document, dataset) + + # Set up Redis cache keys + segment_ids = [segment.id for segment in segments] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.set(indexing_cache_key, "processing", ex=300) + + # Mock the index processor to raise an exception + mock_external_service_dependencies["index_processor"].load.side_effect = Exception("Index processing failed") + + # Act: Execute the task + enable_segments_to_index_task(segment_ids, dataset.id, document.id) + + # Assert: Verify error handling + for segment in segments: + db.session.refresh(segment) + assert segment.enabled is False + assert segment.status == "error" + assert segment.error is not None + assert "Index processing failed" in segment.error + assert segment.disabled_at is not None + + # Verify Redis cache keys were still cleared despite error + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(indexing_cache_key) == 0 diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py new file mode 100644 index 0000000000..31e9b67421 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py @@ -0,0 +1,242 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from extensions.ext_database import db +from libs.email_i18n import EmailType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task + + +class TestMailAccountDeletionTask: + """Integration tests for mail account deletion tasks using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.mail_account_deletion_task.mail") as mock_mail, + patch("tasks.mail_account_deletion_task.get_email_i18n_service") as mock_get_email_service, + ): + # Setup mock mail service + mock_mail.is_inited.return_value = True + + # Setup mock email service + mock_email_service = MagicMock() + mock_get_email_service.return_value = mock_email_service + + yield { + "mail": mock_mail, + "get_email_service": mock_get_email_service, + "email_service": mock_email_service, + } + + def _create_test_account(self, db_session_with_containers): + """ + Helper method to create a test account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + Account: Created account instance + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + return account + + def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful account deletion success email sending. + + This test verifies: + - Proper email service initialization check + - Correct email service method calls + - Template context is properly formatted + - Email type is correctly specified + """ + # Arrange: Create test data + account = self._create_test_account(db_session_with_containers) + test_email = account.email + test_language = "en-US" + + # Act: Execute the task + send_deletion_success_task(test_email, test_language) + + # Assert: Verify the expected outcomes + # Verify mail service was checked + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + + # Verify email service was retrieved + mock_external_service_dependencies["get_email_service"].assert_called_once() + + # Verify email was sent with correct parameters + mock_external_service_dependencies["email_service"].send_email.assert_called_once_with( + email_type=EmailType.ACCOUNT_DELETION_SUCCESS, + language_code=test_language, + to=test_email, + template_context={ + "to": test_email, + "email": test_email, + }, + ) + + def test_send_deletion_success_task_mail_not_initialized( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account deletion success email when mail service is not initialized. + + This test verifies: + - Early return when mail service is not initialized + - No email service calls are made + - No exceptions are raised + """ + # Arrange: Setup mail service to return not initialized + mock_external_service_dependencies["mail"].is_inited.return_value = False + account = self._create_test_account(db_session_with_containers) + test_email = account.email + + # Act: Execute the task + send_deletion_success_task(test_email) + + # Assert: Verify no email service calls were made + mock_external_service_dependencies["get_email_service"].assert_not_called() + mock_external_service_dependencies["email_service"].send_email.assert_not_called() + + def test_send_deletion_success_task_email_service_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account deletion success email when email service raises exception. + + This test verifies: + - Exception is properly caught and logged + - Task completes without raising exception + - Error logging is recorded + """ + # Arrange: Setup email service to raise exception + mock_external_service_dependencies["email_service"].send_email.side_effect = Exception("Email service failed") + account = self._create_test_account(db_session_with_containers) + test_email = account.email + + # Act: Execute the task (should not raise exception) + send_deletion_success_task(test_email) + + # Assert: Verify email service was called but exception was handled + mock_external_service_dependencies["email_service"].send_email.assert_called_once() + + def test_send_account_deletion_verification_code_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful account deletion verification code email sending. + + This test verifies: + - Proper email service initialization check + - Correct email service method calls + - Template context includes verification code + - Email type is correctly specified + """ + # Arrange: Create test data + account = self._create_test_account(db_session_with_containers) + test_email = account.email + test_code = "123456" + test_language = "en-US" + + # Act: Execute the task + send_account_deletion_verification_code(test_email, test_code, test_language) + + # Assert: Verify the expected outcomes + # Verify mail service was checked + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + + # Verify email service was retrieved + mock_external_service_dependencies["get_email_service"].assert_called_once() + + # Verify email was sent with correct parameters + mock_external_service_dependencies["email_service"].send_email.assert_called_once_with( + email_type=EmailType.ACCOUNT_DELETION_VERIFICATION, + language_code=test_language, + to=test_email, + template_context={ + "to": test_email, + "code": test_code, + }, + ) + + def test_send_account_deletion_verification_code_mail_not_initialized( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account deletion verification code email when mail service is not initialized. + + This test verifies: + - Early return when mail service is not initialized + - No email service calls are made + - No exceptions are raised + """ + # Arrange: Setup mail service to return not initialized + mock_external_service_dependencies["mail"].is_inited.return_value = False + account = self._create_test_account(db_session_with_containers) + test_email = account.email + test_code = "123456" + + # Act: Execute the task + send_account_deletion_verification_code(test_email, test_code) + + # Assert: Verify no email service calls were made + mock_external_service_dependencies["get_email_service"].assert_not_called() + mock_external_service_dependencies["email_service"].send_email.assert_not_called() + + def test_send_account_deletion_verification_code_email_service_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account deletion verification code email when email service raises exception. + + This test verifies: + - Exception is properly caught and logged + - Task completes without raising exception + - Error logging is recorded + """ + # Arrange: Setup email service to raise exception + mock_external_service_dependencies["email_service"].send_email.side_effect = Exception("Email service failed") + account = self._create_test_account(db_session_with_containers) + test_email = account.email + test_code = "123456" + + # Act: Execute the task (should not raise exception) + send_account_deletion_verification_code(test_email, test_code) + + # Assert: Verify email service was called but exception was handled + mock_external_service_dependencies["email_service"].send_email.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py new file mode 100644 index 0000000000..1aed7dc7cc --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py @@ -0,0 +1,282 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from libs.email_i18n import EmailType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from tasks.mail_change_mail_task import send_change_mail_completed_notification_task, send_change_mail_task + + +class TestMailChangeMailTask: + """Integration tests for mail_change_mail_task using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.mail_change_mail_task.mail") as mock_mail, + patch("tasks.mail_change_mail_task.get_email_i18n_service") as mock_get_email_i18n_service, + ): + # Setup mock mail service + mock_mail.is_inited.return_value = True + + # Setup mock email i18n service + mock_email_service = MagicMock() + mock_get_email_i18n_service.return_value = mock_email_service + + yield { + "mail": mock_mail, + "email_i18n_service": mock_email_service, + "get_email_i18n_service": mock_get_email_i18n_service, + } + + def _create_test_account(self, db_session_with_containers): + """ + Helper method to create a test account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + Account: Created account instance + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + return account + + def test_send_change_mail_task_success_old_email_phase( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful change email task execution for old_email phase. + + This test verifies: + - Proper mail service initialization check + - Correct email service method call with old_email phase + - Successful task completion + """ + # Arrange: Create test data + account = self._create_test_account(db_session_with_containers) + test_language = "en-US" + test_email = account.email + test_code = "123456" + test_phase = "old_email" + + # Act: Execute the task + send_change_mail_task(test_language, test_email, test_code, test_phase) + + # Assert: Verify the expected outcomes + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + mock_external_service_dependencies["get_email_i18n_service"].assert_called_once() + mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_called_once_with( + language_code=test_language, + to=test_email, + code=test_code, + phase=test_phase, + ) + + def test_send_change_mail_task_success_new_email_phase( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful change email task execution for new_email phase. + + This test verifies: + - Proper mail service initialization check + - Correct email service method call with new_email phase + - Successful task completion + """ + # Arrange: Create test data + account = self._create_test_account(db_session_with_containers) + test_language = "zh-Hans" + test_email = "new@example.com" + test_code = "789012" + test_phase = "new_email" + + # Act: Execute the task + send_change_mail_task(test_language, test_email, test_code, test_phase) + + # Assert: Verify the expected outcomes + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + mock_external_service_dependencies["get_email_i18n_service"].assert_called_once() + mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_called_once_with( + language_code=test_language, + to=test_email, + code=test_code, + phase=test_phase, + ) + + def test_send_change_mail_task_mail_not_initialized( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test change email task when mail service is not initialized. + + This test verifies: + - Early return when mail service is not initialized + - No email service calls when mail is not available + """ + # Arrange: Setup mail service as not initialized + mock_external_service_dependencies["mail"].is_inited.return_value = False + test_language = "en-US" + test_email = "test@example.com" + test_code = "123456" + test_phase = "old_email" + + # Act: Execute the task + send_change_mail_task(test_language, test_email, test_code, test_phase) + + # Assert: Verify no email service calls + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + mock_external_service_dependencies["get_email_i18n_service"].assert_not_called() + mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_not_called() + + def test_send_change_mail_task_email_service_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test change email task when email service raises an exception. + + This test verifies: + - Exception is properly caught and logged + - Task completes without raising exception + """ + # Arrange: Setup email service to raise exception + mock_external_service_dependencies["email_i18n_service"].send_change_email.side_effect = Exception( + "Email service failed" + ) + test_language = "en-US" + test_email = "test@example.com" + test_code = "123456" + test_phase = "old_email" + + # Act: Execute the task (should not raise exception) + send_change_mail_task(test_language, test_email, test_code, test_phase) + + # Assert: Verify email service was called despite exception + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + mock_external_service_dependencies["get_email_i18n_service"].assert_called_once() + mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_called_once_with( + language_code=test_language, + to=test_email, + code=test_code, + phase=test_phase, + ) + + def test_send_change_mail_completed_notification_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful change email completed notification task execution. + + This test verifies: + - Proper mail service initialization check + - Correct email service method call with CHANGE_EMAIL_COMPLETED type + - Template context is properly constructed + - Successful task completion + """ + # Arrange: Create test data + account = self._create_test_account(db_session_with_containers) + test_language = "en-US" + test_email = account.email + + # Act: Execute the task + send_change_mail_completed_notification_task(test_language, test_email) + + # Assert: Verify the expected outcomes + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + mock_external_service_dependencies["get_email_i18n_service"].assert_called_once() + mock_external_service_dependencies["email_i18n_service"].send_email.assert_called_once_with( + email_type=EmailType.CHANGE_EMAIL_COMPLETED, + language_code=test_language, + to=test_email, + template_context={ + "to": test_email, + "email": test_email, + }, + ) + + def test_send_change_mail_completed_notification_task_mail_not_initialized( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test change email completed notification task when mail service is not initialized. + + This test verifies: + - Early return when mail service is not initialized + - No email service calls when mail is not available + """ + # Arrange: Setup mail service as not initialized + mock_external_service_dependencies["mail"].is_inited.return_value = False + test_language = "en-US" + test_email = "test@example.com" + + # Act: Execute the task + send_change_mail_completed_notification_task(test_language, test_email) + + # Assert: Verify no email service calls + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + mock_external_service_dependencies["get_email_i18n_service"].assert_not_called() + mock_external_service_dependencies["email_i18n_service"].send_email.assert_not_called() + + def test_send_change_mail_completed_notification_task_email_service_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test change email completed notification task when email service raises an exception. + + This test verifies: + - Exception is properly caught and logged + - Task completes without raising exception + """ + # Arrange: Setup email service to raise exception + mock_external_service_dependencies["email_i18n_service"].send_email.side_effect = Exception( + "Email service failed" + ) + test_language = "en-US" + test_email = "test@example.com" + + # Act: Execute the task (should not raise exception) + send_change_mail_completed_notification_task(test_language, test_email) + + # Assert: Verify email service was called despite exception + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + mock_external_service_dependencies["get_email_i18n_service"].assert_called_once() + mock_external_service_dependencies["email_i18n_service"].send_email.assert_called_once_with( + email_type=EmailType.CHANGE_EMAIL_COMPLETED, + language_code=test_language, + to=test_email, + template_context={ + "to": test_email, + "email": test_email, + }, + ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py new file mode 100644 index 0000000000..e6a804784a --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -0,0 +1,598 @@ +""" +TestContainers-based integration tests for send_email_code_login_mail_task. + +This module provides comprehensive integration tests for the email code login mail task +using TestContainers infrastructure. The tests ensure that the task properly sends +email verification codes for login with internationalization support and handles +various error scenarios in a real database environment. + +All tests use the testcontainers infrastructure to ensure proper database isolation +and realistic testing scenarios with actual PostgreSQL and Redis instances. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from libs.email_i18n import EmailType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from tasks.mail_email_code_login import send_email_code_login_mail_task + + +class TestSendEmailCodeLoginMailTask: + """ + Comprehensive integration tests for send_email_code_login_mail_task using testcontainers. + + This test class covers all major functionality of the email code login mail task: + - Successful email sending with different languages + - Email service integration and template rendering + - Error handling for various failure scenarios + - Performance metrics and logging verification + - Edge cases and boundary conditions + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database before each test to ensure isolation.""" + from extensions.ext_redis import redis_client + + # Clear all test data + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() + + # Clear Redis cache + redis_client.flushdb() + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.mail_email_code_login.mail") as mock_mail, + patch("tasks.mail_email_code_login.get_email_i18n_service") as mock_email_service, + ): + # Setup default mock returns + mock_mail.is_inited.return_value = True + + # Mock email service + mock_email_service_instance = MagicMock() + mock_email_service_instance.send_email.return_value = None + mock_email_service.return_value = mock_email_service_instance + + yield { + "mail": mock_mail, + "email_service": mock_email_service, + "email_service_instance": mock_email_service_instance, + } + + def _create_test_account(self, db_session_with_containers, fake=None): + """ + Helper method to create a test account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + fake: Faker instance for generating test data + + Returns: + Account: Created account instance + """ + if fake is None: + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + db_session_with_containers.add(account) + db_session_with_containers.commit() + + return account + + def _create_test_tenant_and_account(self, db_session_with_containers, fake=None): + """ + Helper method to create a test tenant and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + fake: Faker instance for generating test data + + Returns: + tuple: (Account, Tenant) created instances + """ + if fake is None: + fake = Faker() + + # Create account using the existing helper method + account = self._create_test_account(db_session_with_containers, fake) + + # Create tenant + tenant = Tenant( + name=fake.company(), + plan="basic", + status="active", + ) + + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # Create tenant-account relationship + tenant_account_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + ) + + db_session_with_containers.add(tenant_account_join) + db_session_with_containers.commit() + + return account, tenant + + def test_send_email_code_login_mail_task_success_english( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful email code login mail sending in English. + + This test verifies that the task can successfully: + 1. Send email code login mail with English language + 2. Use proper email service integration + 3. Pass correct template context to email service + 4. Log performance metrics correctly + 5. Complete task execution without errors + """ + # Arrange: Setup test data + fake = Faker() + test_email = fake.email() + test_code = "123456" + test_language = "en-US" + + # Act: Execute the task + send_email_code_login_mail_task( + language=test_language, + to=test_email, + code=test_code, + ) + + # Assert: Verify expected outcomes + mock_mail = mock_external_service_dependencies["mail"] + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + + # Verify mail service was checked for initialization + mock_mail.is_inited.assert_called_once() + + # Verify email service was called with correct parameters + mock_email_service_instance.send_email.assert_called_once_with( + email_type=EmailType.EMAIL_CODE_LOGIN, + language_code=test_language, + to=test_email, + template_context={ + "to": test_email, + "code": test_code, + }, + ) + + def test_send_email_code_login_mail_task_success_chinese( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful email code login mail sending in Chinese. + + This test verifies that the task can successfully: + 1. Send email code login mail with Chinese language + 2. Handle different language codes properly + 3. Use correct template context for Chinese emails + 4. Complete task execution without errors + """ + # Arrange: Setup test data + fake = Faker() + test_email = fake.email() + test_code = "789012" + test_language = "zh-Hans" + + # Act: Execute the task + send_email_code_login_mail_task( + language=test_language, + to=test_email, + code=test_code, + ) + + # Assert: Verify expected outcomes + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + + # Verify email service was called with Chinese language + mock_email_service_instance.send_email.assert_called_once_with( + email_type=EmailType.EMAIL_CODE_LOGIN, + language_code=test_language, + to=test_email, + template_context={ + "to": test_email, + "code": test_code, + }, + ) + + def test_send_email_code_login_mail_task_success_multiple_languages( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful email code login mail sending with multiple languages. + + This test verifies that the task can successfully: + 1. Handle various language codes correctly + 2. Send emails with different language configurations + 3. Maintain proper template context for each language + 4. Complete multiple task executions without conflicts + """ + # Arrange: Setup test data + fake = Faker() + test_languages = ["en-US", "zh-Hans", "zh-CN", "ja-JP", "ko-KR"] + test_emails = [fake.email() for _ in test_languages] + test_codes = [fake.numerify("######") for _ in test_languages] + + # Act: Execute the task for each language + for i, language in enumerate(test_languages): + send_email_code_login_mail_task( + language=language, + to=test_emails[i], + code=test_codes[i], + ) + + # Assert: Verify expected outcomes + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + + # Verify email service was called for each language + assert mock_email_service_instance.send_email.call_count == len(test_languages) + + # Verify each call had correct parameters + for i, language in enumerate(test_languages): + call_args = mock_email_service_instance.send_email.call_args_list[i] + assert call_args[1]["email_type"] == EmailType.EMAIL_CODE_LOGIN + assert call_args[1]["language_code"] == language + assert call_args[1]["to"] == test_emails[i] + assert call_args[1]["template_context"]["code"] == test_codes[i] + + def test_send_email_code_login_mail_task_mail_not_initialized( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email code login mail task when mail service is not initialized. + + This test verifies that the task can properly: + 1. Check mail service initialization status + 2. Return early when mail is not initialized + 3. Not attempt to send email when service is unavailable + 4. Handle gracefully without errors + """ + # Arrange: Setup test data + fake = Faker() + test_email = fake.email() + test_code = "123456" + test_language = "en-US" + + # Mock mail service as not initialized + mock_mail = mock_external_service_dependencies["mail"] + mock_mail.is_inited.return_value = False + + # Act: Execute the task + send_email_code_login_mail_task( + language=test_language, + to=test_email, + code=test_code, + ) + + # Assert: Verify expected outcomes + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + + # Verify mail service was checked for initialization + mock_mail.is_inited.assert_called_once() + + # Verify email service was not called + mock_email_service_instance.send_email.assert_not_called() + + def test_send_email_code_login_mail_task_email_service_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email code login mail task when email service raises an exception. + + This test verifies that the task can properly: + 1. Handle email service exceptions gracefully + 2. Log appropriate error messages + 3. Continue execution without crashing + 4. Maintain proper error handling + """ + # Arrange: Setup test data + fake = Faker() + test_email = fake.email() + test_code = "123456" + test_language = "en-US" + + # Mock email service to raise an exception + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + mock_email_service_instance.send_email.side_effect = Exception("Email service unavailable") + + # Act: Execute the task - it should handle the exception gracefully + send_email_code_login_mail_task( + language=test_language, + to=test_email, + code=test_code, + ) + + # Assert: Verify expected outcomes + mock_mail = mock_external_service_dependencies["mail"] + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + + # Verify mail service was checked for initialization + mock_mail.is_inited.assert_called_once() + + # Verify email service was called (and failed) + mock_email_service_instance.send_email.assert_called_once_with( + email_type=EmailType.EMAIL_CODE_LOGIN, + language_code=test_language, + to=test_email, + template_context={ + "to": test_email, + "code": test_code, + }, + ) + + def test_send_email_code_login_mail_task_invalid_parameters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email code login mail task with invalid parameters. + + This test verifies that the task can properly: + 1. Handle empty or None email addresses + 2. Process empty or None verification codes + 3. Handle invalid language codes + 4. Maintain proper error handling for invalid inputs + """ + # Arrange: Setup test data + fake = Faker() + test_language = "en-US" + + # Test cases for invalid parameters + invalid_test_cases = [ + {"email": "", "code": "123456", "description": "empty email"}, + {"email": None, "code": "123456", "description": "None email"}, + {"email": fake.email(), "code": "", "description": "empty code"}, + {"email": fake.email(), "code": None, "description": "None code"}, + {"email": "invalid-email", "code": "123456", "description": "invalid email format"}, + ] + + for test_case in invalid_test_cases: + # Reset mocks for each test case + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + mock_email_service_instance.reset_mock() + + # Act: Execute the task with invalid parameters + send_email_code_login_mail_task( + language=test_language, + to=test_case["email"], + code=test_case["code"], + ) + + # Assert: Verify that email service was still called + # The task should pass parameters to email service as-is + # and let the email service handle validation + mock_email_service_instance.send_email.assert_called_once() + + def test_send_email_code_login_mail_task_edge_cases( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email code login mail task with edge cases and boundary conditions. + + This test verifies that the task can properly: + 1. Handle very long email addresses + 2. Process very long verification codes + 3. Handle special characters in parameters + 4. Process extreme language codes + """ + # Arrange: Setup test data + fake = Faker() + test_language = "en-US" + + # Edge case test data + edge_cases = [ + { + "email": "a" * 100 + "@example.com", # Very long email + "code": "1" * 20, # Very long code + "description": "very long email and code", + }, + { + "email": "test+tag@example.com", # Email with special characters + "code": "123-456", # Code with special characters + "description": "special characters", + }, + { + "email": "test@sub.domain.example.com", # Complex domain + "code": "000000", # All zeros + "description": "complex domain and all zeros code", + }, + { + "email": "test@example.co.uk", # International domain + "code": "999999", # All nines + "description": "international domain and all nines code", + }, + ] + + for test_case in edge_cases: + # Reset mocks for each test case + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + mock_email_service_instance.reset_mock() + + # Act: Execute the task with edge case data + send_email_code_login_mail_task( + language=test_language, + to=test_case["email"], + code=test_case["code"], + ) + + # Assert: Verify that email service was called with edge case data + mock_email_service_instance.send_email.assert_called_once_with( + email_type=EmailType.EMAIL_CODE_LOGIN, + language_code=test_language, + to=test_case["email"], + template_context={ + "to": test_case["email"], + "code": test_case["code"], + }, + ) + + def test_send_email_code_login_mail_task_database_integration( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email code login mail task with database integration. + + This test verifies that the task can properly: + 1. Work with real database connections + 2. Handle database session management + 3. Maintain proper database state + 4. Complete without database-related errors + """ + # Arrange: Setup test data with database + fake = Faker() + account, tenant = self._create_test_tenant_and_account(db_session_with_containers, fake) + + test_email = account.email + test_code = "123456" + test_language = "en-US" + + # Act: Execute the task + send_email_code_login_mail_task( + language=test_language, + to=test_email, + code=test_code, + ) + + # Assert: Verify expected outcomes + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + + # Verify email service was called with database account email + mock_email_service_instance.send_email.assert_called_once_with( + email_type=EmailType.EMAIL_CODE_LOGIN, + language_code=test_language, + to=test_email, + template_context={ + "to": test_email, + "code": test_code, + }, + ) + + # Verify database state is maintained + db_session_with_containers.refresh(account) + assert account.email == test_email + assert account.status == "active" + + def test_send_email_code_login_mail_task_redis_integration( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email code login mail task with Redis integration. + + This test verifies that the task can properly: + 1. Work with Redis cache connections + 2. Handle Redis operations without errors + 3. Maintain proper cache state + 4. Complete without Redis-related errors + """ + # Arrange: Setup test data + fake = Faker() + test_email = fake.email() + test_code = "123456" + test_language = "en-US" + + # Setup Redis cache data + from extensions.ext_redis import redis_client + + cache_key = f"email_code_login_test_{test_email}" + redis_client.set(cache_key, "test_value", ex=300) + + # Act: Execute the task + send_email_code_login_mail_task( + language=test_language, + to=test_email, + code=test_code, + ) + + # Assert: Verify expected outcomes + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + + # Verify email service was called + mock_email_service_instance.send_email.assert_called_once() + + # Verify Redis cache is still accessible + assert redis_client.exists(cache_key) == 1 + assert redis_client.get(cache_key) == b"test_value" + + # Clean up Redis cache + redis_client.delete(cache_key) + + def test_send_email_code_login_mail_task_error_handling_comprehensive( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test comprehensive error handling for email code login mail task. + + This test verifies that the task can properly: + 1. Handle various types of exceptions + 2. Log appropriate error messages + 3. Continue execution despite errors + 4. Maintain proper error reporting + """ + # Arrange: Setup test data + fake = Faker() + test_email = fake.email() + test_code = "123456" + test_language = "en-US" + + # Test different exception types + exception_types = [ + ("ValueError", ValueError("Invalid email format")), + ("RuntimeError", RuntimeError("Service unavailable")), + ("ConnectionError", ConnectionError("Network error")), + ("TimeoutError", TimeoutError("Request timeout")), + ("Exception", Exception("Generic error")), + ] + + for error_name, exception in exception_types: + # Reset mocks for each test case + mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] + mock_email_service_instance.reset_mock() + mock_email_service_instance.send_email.side_effect = exception + + # Mock logging to capture error messages + with patch("tasks.mail_email_code_login.logger") as mock_logger: + # Act: Execute the task - it should handle the exception gracefully + send_email_code_login_mail_task( + language=test_language, + to=test_email, + code=test_code, + ) + + # Assert: Verify error handling + # Verify email service was called (and failed) + mock_email_service_instance.send_email.assert_called_once() + + # Verify error was logged + error_calls = [ + call + for call in mock_logger.exception.call_args_list + if f"Send email code login mail to {test_email} failed" in str(call) + ] + # Check if any exception call was made (the exact message format may vary) + assert mock_logger.exception.call_count >= 1, f"Error should be logged for {error_name}" + + # Reset side effect for next iteration + mock_email_service_instance.send_email.side_effect = None diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py new file mode 100644 index 0000000000..d67794654f --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py @@ -0,0 +1,261 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from tasks.mail_inner_task import send_inner_email_task + + +class TestMailInnerTask: + """Integration tests for send_inner_email_task using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.mail_inner_task.mail") as mock_mail, + patch("tasks.mail_inner_task.get_email_i18n_service") as mock_get_email_i18n_service, + patch("tasks.mail_inner_task._render_template_with_strategy") as mock_render_template, + ): + # Setup mock mail service + mock_mail.is_inited.return_value = True + + # Setup mock email i18n service + mock_email_service = MagicMock() + mock_get_email_i18n_service.return_value = mock_email_service + + # Setup mock template rendering + mock_render_template.return_value = "Test email content" + + yield { + "mail": mock_mail, + "email_service": mock_email_service, + "render_template": mock_render_template, + } + + def _create_test_email_data(self, fake: Faker) -> dict: + """ + Helper method to create test email data for testing. + + Args: + fake: Faker instance for generating test data + + Returns: + dict: Test email data including recipients, subject, body, and substitutions + """ + return { + "to": [fake.email() for _ in range(3)], + "subject": fake.sentence(nb_words=4), + "body": "Hello {{name}}, this is a test email from {{company}}.", + "substitutions": { + "name": fake.name(), + "company": fake.company(), + "date": fake.date(), + }, + } + + def test_send_inner_email_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful email sending with valid data. + + This test verifies: + - Proper email service initialization check + - Template rendering with substitutions + - Email service integration + - Multiple recipient handling + """ + # Arrange: Create test data + fake = Faker() + email_data = self._create_test_email_data(fake) + + # Act: Execute the task + send_inner_email_task( + to=email_data["to"], + subject=email_data["subject"], + body=email_data["body"], + substitutions=email_data["substitutions"], + ) + + # Assert: Verify the expected outcomes + # Verify mail service was checked for initialization + mock_external_service_dependencies["mail"].is_inited.assert_called_once() + + # Verify template rendering was called with correct parameters + mock_external_service_dependencies["render_template"].assert_called_once_with( + email_data["body"], email_data["substitutions"] + ) + + # Verify email service was called once with the full recipient list + mock_email_service = mock_external_service_dependencies["email_service"] + mock_email_service.send_raw_email.assert_called_once_with( + to=email_data["to"], + subject=email_data["subject"], + html_content="Test email content", + ) + + def test_send_inner_email_single_recipient(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test email sending with single recipient. + + This test verifies: + - Single recipient handling + - Template rendering + - Email service integration + """ + # Arrange: Create test data with single recipient + fake = Faker() + email_data = { + "to": [fake.email()], + "subject": fake.sentence(nb_words=3), + "body": "Welcome {{user_name}}!", + "substitutions": { + "user_name": fake.name(), + }, + } + + # Act: Execute the task + send_inner_email_task( + to=email_data["to"], + subject=email_data["subject"], + body=email_data["body"], + substitutions=email_data["substitutions"], + ) + + # Assert: Verify the expected outcomes + mock_email_service = mock_external_service_dependencies["email_service"] + mock_email_service.send_raw_email.assert_called_once_with( + to=email_data["to"], + subject=email_data["subject"], + html_content="Test email content", + ) + + def test_send_inner_email_empty_substitutions(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test email sending with empty substitutions. + + This test verifies: + - Template rendering with empty substitutions + - Email service integration + - Handling of minimal template context + """ + # Arrange: Create test data with empty substitutions + fake = Faker() + email_data = { + "to": [fake.email()], + "subject": fake.sentence(nb_words=3), + "body": "This is a simple email without variables.", + "substitutions": {}, + } + + # Act: Execute the task + send_inner_email_task( + to=email_data["to"], + subject=email_data["subject"], + body=email_data["body"], + substitutions=email_data["substitutions"], + ) + + # Assert: Verify the expected outcomes + mock_external_service_dependencies["render_template"].assert_called_once_with(email_data["body"], {}) + + mock_email_service = mock_external_service_dependencies["email_service"] + mock_email_service.send_raw_email.assert_called_once_with( + to=email_data["to"], + subject=email_data["subject"], + html_content="Test email content", + ) + + def test_send_inner_email_mail_not_initialized( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email sending when mail service is not initialized. + + This test verifies: + - Early return when mail service is not initialized + - No template rendering occurs + - No email service calls + - No exceptions raised + """ + # Arrange: Setup mail service as not initialized + mock_external_service_dependencies["mail"].is_inited.return_value = False + + fake = Faker() + email_data = self._create_test_email_data(fake) + + # Act: Execute the task + send_inner_email_task( + to=email_data["to"], + subject=email_data["subject"], + body=email_data["body"], + substitutions=email_data["substitutions"], + ) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["render_template"].assert_not_called() + mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() + + def test_send_inner_email_template_rendering_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email sending when template rendering fails. + + This test verifies: + - Exception handling during template rendering + - No email service calls when template fails + """ + # Arrange: Setup template rendering to raise an exception + mock_external_service_dependencies["render_template"].side_effect = Exception("Template rendering failed") + + fake = Faker() + email_data = self._create_test_email_data(fake) + + # Act: Execute the task + send_inner_email_task( + to=email_data["to"], + subject=email_data["subject"], + body=email_data["body"], + substitutions=email_data["substitutions"], + ) + + # Assert: Verify template rendering was attempted + mock_external_service_dependencies["render_template"].assert_called_once() + + # Verify no email service calls due to exception + mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() + + def test_send_inner_email_service_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test email sending when email service fails. + + This test verifies: + - Exception handling during email sending + - Graceful error handling + """ + # Arrange: Setup email service to raise an exception + mock_external_service_dependencies["email_service"].send_raw_email.side_effect = Exception( + "Email service failed" + ) + + fake = Faker() + email_data = self._create_test_email_data(fake) + + # Act: Execute the task + send_inner_email_task( + to=email_data["to"], + subject=email_data["subject"], + body=email_data["body"], + substitutions=email_data["substitutions"], + ) + + # Assert: Verify template rendering occurred + mock_external_service_dependencies["render_template"].assert_called_once() + + # Verify email service was called (and failed) + mock_email_service = mock_external_service_dependencies["email_service"] + mock_email_service.send_raw_email.assert_called_once_with( + to=email_data["to"], + subject=email_data["subject"], + html_content="Test email content", + ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py new file mode 100644 index 0000000000..c083861004 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -0,0 +1,544 @@ +""" +Integration tests for mail_invite_member_task using testcontainers. + +This module provides integration tests for the invite member email task +using TestContainers infrastructure. The tests ensure that the task properly sends +invitation emails with internationalization support, handles error scenarios, +and integrates correctly with the database and Redis for token management. + +All tests use the testcontainers infrastructure to ensure proper database isolation +and realistic testing scenarios with actual PostgreSQL and Redis instances. +""" + +import json +import uuid +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from extensions.ext_redis import redis_client +from libs.email_i18n import EmailType +from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole +from tasks.mail_invite_member_task import send_invite_member_mail_task + + +class TestMailInviteMemberTask: + """ + Integration tests for send_invite_member_mail_task using testcontainers. + + This test class covers the core functionality of the invite member email task: + - Email sending with proper internationalization + - Template context generation and URL construction + - Error handling for failure scenarios + - Integration with Redis for token validation + - Mail service initialization checks + - Real database integration with actual invitation flow + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database and Redis interactions. + """ + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database before each test to ensure isolation.""" + # Clear all test data + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() + + # Clear Redis cache + redis_client.flushdb() + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.mail_invite_member_task.mail") as mock_mail, + patch("tasks.mail_invite_member_task.get_email_i18n_service") as mock_email_service, + patch("tasks.mail_invite_member_task.dify_config") as mock_config, + ): + # Setup mail service mock + mock_mail.is_inited.return_value = True + + # Setup email service mock + mock_email_service_instance = MagicMock() + mock_email_service_instance.send_email.return_value = None + mock_email_service.return_value = mock_email_service_instance + + # Setup config mock + mock_config.CONSOLE_WEB_URL = "https://console.dify.ai" + + yield { + "mail": mock_mail, + "email_service": mock_email_service_instance, + "config": mock_config, + } + + def _create_test_account_and_tenant(self, db_session_with_containers): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (Account, Tenant) created instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + password=fake.password(), + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.created_at = datetime.now(UTC) + account.updated_at = datetime.now(UTC) + db_session_with_containers.add(account) + db_session_with_containers.commit() + db_session_with_containers.refresh(account) + + # Create tenant + tenant = Tenant( + name=fake.company(), + ) + tenant.created_at = datetime.now(UTC) + tenant.updated_at = datetime.now(UTC) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + db_session_with_containers.refresh(tenant) + + # Create tenant member relationship + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + ) + tenant_join.created_at = datetime.now(UTC) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + return account, tenant + + def _create_invitation_token(self, tenant, account): + """ + Helper method to create a valid invitation token in Redis. + + Args: + tenant: Tenant instance + account: Account instance + + Returns: + str: Generated invitation token + """ + token = str(uuid.uuid4()) + invitation_data = { + "account_id": account.id, + "email": account.email, + "workspace_id": tenant.id, + } + cache_key = f"member_invite:token:{token}" + redis_client.setex(cache_key, 24 * 60 * 60, json.dumps(invitation_data)) # 24 hours + return token + + def _create_pending_account_for_invitation(self, db_session_with_containers, email, tenant): + """ + Helper method to create a pending account for invitation testing. + + Args: + db_session_with_containers: Database session + email: Email address for the account + tenant: Tenant instance + + Returns: + Account: Created pending account + """ + account = Account( + email=email, + name=email.split("@")[0], + password="", + interface_language="en-US", + status=AccountStatus.PENDING, + ) + + account.created_at = datetime.now(UTC) + account.updated_at = datetime.now(UTC) + db_session_with_containers.add(account) + db_session_with_containers.commit() + db_session_with_containers.refresh(account) + + # Create tenant member relationship + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.NORMAL, + ) + tenant_join.created_at = datetime.now(UTC) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + return account + + def test_send_invite_member_mail_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful invitation email sending with all parameters. + + This test verifies: + - Email service is called with correct parameters + - Template context includes all required fields + - URL is constructed correctly with token + - Performance logging is recorded + - No exceptions are raised + """ + # Arrange: Create test data + inviter, tenant = self._create_test_account_and_tenant(db_session_with_containers) + invitee_email = "test@example.com" + language = "en-US" + token = self._create_invitation_token(tenant, inviter) + inviter_name = inviter.name + workspace_name = tenant.name + + # Act: Execute the task + send_invite_member_mail_task( + language=language, + to=invitee_email, + token=token, + inviter_name=inviter_name, + workspace_name=workspace_name, + ) + + # Assert: Verify email service was called correctly + mock_email_service = mock_external_service_dependencies["email_service"] + mock_email_service.send_email.assert_called_once() + + # Verify call arguments + call_args = mock_email_service.send_email.call_args + assert call_args[1]["email_type"] == EmailType.INVITE_MEMBER + assert call_args[1]["language_code"] == language + assert call_args[1]["to"] == invitee_email + + # Verify template context + template_context = call_args[1]["template_context"] + assert template_context["to"] == invitee_email + assert template_context["inviter_name"] == inviter_name + assert template_context["workspace_name"] == workspace_name + assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" + + def test_send_invite_member_mail_different_languages( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test invitation email sending with different language codes. + + This test verifies: + - Email service handles different language codes correctly + - Template context is passed correctly for each language + - No language-specific errors occur + """ + # Arrange: Create test data + inviter, tenant = self._create_test_account_and_tenant(db_session_with_containers) + token = self._create_invitation_token(tenant, inviter) + + test_languages = ["en-US", "zh-CN", "ja-JP", "fr-FR", "de-DE", "es-ES"] + + for language in test_languages: + # Act: Execute the task with different language + send_invite_member_mail_task( + language=language, + to="test@example.com", + token=token, + inviter_name=inviter.name, + workspace_name=tenant.name, + ) + + # Assert: Verify language code was passed correctly + mock_email_service = mock_external_service_dependencies["email_service"] + call_args = mock_email_service.send_email.call_args + assert call_args[1]["language_code"] == language + + def test_send_invite_member_mail_mail_not_initialized( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test behavior when mail service is not initialized. + + This test verifies: + - Task returns early when mail is not initialized + - Email service is not called + - No exceptions are raised + """ + # Arrange: Setup mail service as not initialized + mock_mail = mock_external_service_dependencies["mail"] + mock_mail.is_inited.return_value = False + + # Act: Execute the task + result = send_invite_member_mail_task( + language="en-US", + to="test@example.com", + token="test-token", + inviter_name="Test User", + workspace_name="Test Workspace", + ) + + # Assert: Verify early return + assert result is None + mock_email_service = mock_external_service_dependencies["email_service"] + mock_email_service.send_email.assert_not_called() + + def test_send_invite_member_mail_email_service_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when email service raises an exception. + + This test verifies: + - Exception is caught and logged + - Task completes without raising exception + - Error logging is performed + """ + # Arrange: Setup email service to raise exception + mock_email_service = mock_external_service_dependencies["email_service"] + mock_email_service.send_email.side_effect = Exception("Email service failed") + + # Act & Assert: Execute task and verify exception is handled + with patch("tasks.mail_invite_member_task.logger") as mock_logger: + send_invite_member_mail_task( + language="en-US", + to="test@example.com", + token="test-token", + inviter_name="Test User", + workspace_name="Test Workspace", + ) + + # Verify error was logged + mock_logger.exception.assert_called_once() + error_call = mock_logger.exception.call_args[0][0] + assert "Send invite member mail to %s failed" in error_call + + def test_send_invite_member_mail_template_context_validation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test template context contains all required fields for email rendering. + + This test verifies: + - All required template context fields are present + - Field values match expected data + - URL construction is correct + - No missing or None values in context + """ + # Arrange: Create test data with specific values + inviter, tenant = self._create_test_account_and_tenant(db_session_with_containers) + token = "test-token-123" + invitee_email = "invitee@example.com" + inviter_name = "John Doe" + workspace_name = "Acme Corp" + + # Act: Execute the task + send_invite_member_mail_task( + language="en-US", + to=invitee_email, + token=token, + inviter_name=inviter_name, + workspace_name=workspace_name, + ) + + # Assert: Verify template context + mock_email_service = mock_external_service_dependencies["email_service"] + call_args = mock_email_service.send_email.call_args + template_context = call_args[1]["template_context"] + + # Verify all required fields are present + required_fields = ["to", "inviter_name", "workspace_name", "url"] + for field in required_fields: + assert field in template_context + assert template_context[field] is not None + assert template_context[field] != "" + + # Verify specific values + assert template_context["to"] == invitee_email + assert template_context["inviter_name"] == inviter_name + assert template_context["workspace_name"] == workspace_name + assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" + + def test_send_invite_member_mail_integration_with_redis_token( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test integration with Redis token validation. + + This test verifies: + - Task works with real Redis token data + - Token validation can be performed after email sending + - Redis data integrity is maintained + """ + # Arrange: Create test data and store token in Redis + inviter, tenant = self._create_test_account_and_tenant(db_session_with_containers) + token = self._create_invitation_token(tenant, inviter) + + # Verify token exists in Redis before sending email + cache_key = f"member_invite:token:{token}" + assert redis_client.exists(cache_key) == 1 + + # Act: Execute the task + send_invite_member_mail_task( + language="en-US", + to=inviter.email, + token=token, + inviter_name=inviter.name, + workspace_name=tenant.name, + ) + + # Assert: Verify token still exists after email sending + assert redis_client.exists(cache_key) == 1 + + # Verify token data integrity + token_data = redis_client.get(cache_key) + assert token_data is not None + invitation_data = json.loads(token_data) + assert invitation_data["account_id"] == inviter.id + assert invitation_data["email"] == inviter.email + assert invitation_data["workspace_id"] == tenant.id + + def test_send_invite_member_mail_with_special_characters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test email sending with special characters in names and workspace names. + + This test verifies: + - Special characters are handled correctly in template context + - Email service receives properly formatted data + - No encoding issues occur + """ + # Arrange: Create test data with special characters + inviter, tenant = self._create_test_account_and_tenant(db_session_with_containers) + token = self._create_invitation_token(tenant, inviter) + + special_cases = [ + ("John O'Connor", "Acme & Co."), + ("José María", "Café & Restaurant"), + ("李小明", "北京科技有限公司"), + ("François & Marie", "L'École Internationale"), + ("Александр", "ООО Технологии"), + ("محمد أحمد", "شركة التقنية المتقدمة"), + ] + + for inviter_name, workspace_name in special_cases: + # Act: Execute the task + send_invite_member_mail_task( + language="en-US", + to="test@example.com", + token=token, + inviter_name=inviter_name, + workspace_name=workspace_name, + ) + + # Assert: Verify special characters are preserved + mock_email_service = mock_external_service_dependencies["email_service"] + call_args = mock_email_service.send_email.call_args + template_context = call_args[1]["template_context"] + + assert template_context["inviter_name"] == inviter_name + assert template_context["workspace_name"] == workspace_name + + def test_send_invite_member_mail_real_database_integration( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test real database integration with actual invitation flow. + + This test verifies: + - Task works with real database entities + - Account and tenant relationships are properly maintained + - Database state is consistent after email sending + - Real invitation data flow is tested + """ + # Arrange: Create real database entities + inviter, tenant = self._create_test_account_and_tenant(db_session_with_containers) + invitee_email = "newmember@example.com" + + # Create a pending account for invitation (simulating real invitation flow) + pending_account = self._create_pending_account_for_invitation(db_session_with_containers, invitee_email, tenant) + + # Create invitation token with real account data + token = self._create_invitation_token(tenant, pending_account) + + # Act: Execute the task with real data + send_invite_member_mail_task( + language="en-US", + to=invitee_email, + token=token, + inviter_name=inviter.name, + workspace_name=tenant.name, + ) + + # Assert: Verify email service was called with real data + mock_email_service = mock_external_service_dependencies["email_service"] + mock_email_service.send_email.assert_called_once() + + # Verify database state is maintained + db_session_with_containers.refresh(pending_account) + db_session_with_containers.refresh(tenant) + + assert pending_account.status == AccountStatus.PENDING + assert pending_account.email == invitee_email + assert tenant.name is not None + + # Verify tenant relationship exists + tenant_join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=pending_account.id) + .first() + ) + assert tenant_join is not None + assert tenant_join.role == TenantAccountRole.NORMAL + + def test_send_invite_member_mail_token_lifecycle_management( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test token lifecycle management and validation. + + This test verifies: + - Token is properly stored in Redis with correct TTL + - Token data structure is correct + - Token can be retrieved and validated after email sending + - Token expiration is handled correctly + """ + # Arrange: Create test data + inviter, tenant = self._create_test_account_and_tenant(db_session_with_containers) + token = self._create_invitation_token(tenant, inviter) + + # Act: Execute the task + send_invite_member_mail_task( + language="en-US", + to=inviter.email, + token=token, + inviter_name=inviter.name, + workspace_name=tenant.name, + ) + + # Assert: Verify token lifecycle + cache_key = f"member_invite:token:{token}" + + # Token should still exist + assert redis_client.exists(cache_key) == 1 + + # Token should have correct TTL (approximately 24 hours) + ttl = redis_client.ttl(cache_key) + assert 23 * 60 * 60 <= ttl <= 24 * 60 * 60 # Allow some tolerance + + # Token data should be valid + token_data = redis_client.get(cache_key) + assert token_data is not None + + invitation_data = json.loads(token_data) + assert invitation_data["account_id"] == inviter.id + assert invitation_data["email"] == inviter.email + assert invitation_data["workspace_id"] == tenant.id diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index fbe14f1cb5..209b6bf59b 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -15,13 +15,13 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") + monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") # Custom value for testing monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") monkeypatch.setenv("DB_PORT", "5432") monkeypatch.setenv("DB_DATABASE", "dify") - monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "600") + monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing # load dotenv file with pydantic-settings config = DifyConfig() @@ -33,19 +33,38 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch): assert config.EDITION == "SELF_HOSTED" assert config.API_COMPRESSION_ENABLED is False assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0 + assert config.TEMPLATE_TRANSFORM_MAX_LENGTH == 400_000 - # annotated field with default value - assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 600 + # annotated field with custom configured value + assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 300 - # annotated field with configured value + # annotated field with custom configured value assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30 - assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3 - # values from pyproject.toml assert Version(config.project.version) >= Version("1.0.0") +def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch): + """Test that HTTP timeout defaults are correctly set""" + # clear system environment variables + os.environ.clear() + + # Set minimal required env vars + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + config = DifyConfig() + + # Verify default timeout values + assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10 + assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 600 + assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 600 + + # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. def test_flask_configs(monkeypatch: pytest.MonkeyPatch): @@ -56,7 +75,6 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -106,7 +124,6 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") diff --git a/api/tests/unit_tests/controllers/console/app/test_description_validation.py b/api/tests/unit_tests/controllers/console/app/test_description_validation.py index 178267e560..dcc408a21c 100644 --- a/api/tests/unit_tests/controllers/console/app/test_description_validation.py +++ b/api/tests/unit_tests/controllers/console/app/test_description_validation.py @@ -1,174 +1,53 @@ import pytest -from controllers.console.app.app import _validate_description_length as app_validate -from controllers.console.datasets.datasets import _validate_description_length as dataset_validate -from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate +from libs.validators import validate_description_length class TestDescriptionValidationUnit: - """Unit tests for description validation functions in App and Dataset APIs""" + """Unit tests for the centralized description validation function.""" - def test_app_validate_description_length_valid(self): - """Test App validation function with valid descriptions""" + def test_validate_description_length_valid(self): + """Test validation function with valid descriptions.""" # Empty string should be valid - assert app_validate("") == "" + assert validate_description_length("") == "" # None should be valid - assert app_validate(None) is None + assert validate_description_length(None) is None # Short description should be valid short_desc = "Short description" - assert app_validate(short_desc) == short_desc + assert validate_description_length(short_desc) == short_desc # Exactly 400 characters should be valid exactly_400 = "x" * 400 - assert app_validate(exactly_400) == exactly_400 + assert validate_description_length(exactly_400) == exactly_400 # Just under limit should be valid just_under = "x" * 399 - assert app_validate(just_under) == just_under + assert validate_description_length(just_under) == just_under - def test_app_validate_description_length_invalid(self): - """Test App validation function with invalid descriptions""" + def test_validate_description_length_invalid(self): + """Test validation function with invalid descriptions.""" # 401 characters should fail just_over = "x" * 401 with pytest.raises(ValueError) as exc_info: - app_validate(just_over) + validate_description_length(just_over) assert "Description cannot exceed 400 characters." in str(exc_info.value) # 500 characters should fail way_over = "x" * 500 with pytest.raises(ValueError) as exc_info: - app_validate(way_over) + validate_description_length(way_over) assert "Description cannot exceed 400 characters." in str(exc_info.value) # 1000 characters should fail very_long = "x" * 1000 with pytest.raises(ValueError) as exc_info: - app_validate(very_long) + validate_description_length(very_long) assert "Description cannot exceed 400 characters." in str(exc_info.value) - def test_dataset_validate_description_length_valid(self): - """Test Dataset validation function with valid descriptions""" - # Empty string should be valid - assert dataset_validate("") == "" - - # Short description should be valid - short_desc = "Short description" - assert dataset_validate(short_desc) == short_desc - - # Exactly 400 characters should be valid - exactly_400 = "x" * 400 - assert dataset_validate(exactly_400) == exactly_400 - - # Just under limit should be valid - just_under = "x" * 399 - assert dataset_validate(just_under) == just_under - - def test_dataset_validate_description_length_invalid(self): - """Test Dataset validation function with invalid descriptions""" - # 401 characters should fail - just_over = "x" * 401 - with pytest.raises(ValueError) as exc_info: - dataset_validate(just_over) - assert "Description cannot exceed 400 characters." in str(exc_info.value) - - # 500 characters should fail - way_over = "x" * 500 - with pytest.raises(ValueError) as exc_info: - dataset_validate(way_over) - assert "Description cannot exceed 400 characters." in str(exc_info.value) - - def test_service_dataset_validate_description_length_valid(self): - """Test Service Dataset validation function with valid descriptions""" - # Empty string should be valid - assert service_dataset_validate("") == "" - - # None should be valid - assert service_dataset_validate(None) is None - - # Short description should be valid - short_desc = "Short description" - assert service_dataset_validate(short_desc) == short_desc - - # Exactly 400 characters should be valid - exactly_400 = "x" * 400 - assert service_dataset_validate(exactly_400) == exactly_400 - - # Just under limit should be valid - just_under = "x" * 399 - assert service_dataset_validate(just_under) == just_under - - def test_service_dataset_validate_description_length_invalid(self): - """Test Service Dataset validation function with invalid descriptions""" - # 401 characters should fail - just_over = "x" * 401 - with pytest.raises(ValueError) as exc_info: - service_dataset_validate(just_over) - assert "Description cannot exceed 400 characters." in str(exc_info.value) - - # 500 characters should fail - way_over = "x" * 500 - with pytest.raises(ValueError) as exc_info: - service_dataset_validate(way_over) - assert "Description cannot exceed 400 characters." in str(exc_info.value) - - def test_app_dataset_validation_consistency(self): - """Test that App and Dataset validation functions behave identically""" - test_cases = [ - "", # Empty string - "Short description", # Normal description - "x" * 100, # Medium description - "x" * 400, # Exactly at limit - ] - - # Test valid cases produce same results - for test_desc in test_cases: - assert app_validate(test_desc) == dataset_validate(test_desc) == service_dataset_validate(test_desc) - - # Test invalid cases produce same errors - invalid_cases = [ - "x" * 401, # Just over limit - "x" * 500, # Way over limit - "x" * 1000, # Very long - ] - - for invalid_desc in invalid_cases: - app_error = None - dataset_error = None - service_dataset_error = None - - # Capture App validation error - try: - app_validate(invalid_desc) - except ValueError as e: - app_error = str(e) - - # Capture Dataset validation error - try: - dataset_validate(invalid_desc) - except ValueError as e: - dataset_error = str(e) - - # Capture Service Dataset validation error - try: - service_dataset_validate(invalid_desc) - except ValueError as e: - service_dataset_error = str(e) - - # All should produce errors - assert app_error is not None, f"App validation should fail for {len(invalid_desc)} characters" - assert dataset_error is not None, f"Dataset validation should fail for {len(invalid_desc)} characters" - error_msg = f"Service Dataset validation should fail for {len(invalid_desc)} characters" - assert service_dataset_error is not None, error_msg - - # Errors should be identical - error_msg = f"Error messages should be identical for {len(invalid_desc)} characters" - assert app_error == dataset_error == service_dataset_error, error_msg - assert app_error == "Description cannot exceed 400 characters." - def test_boundary_values(self): - """Test boundary values around the 400 character limit""" + """Test boundary values around the 400 character limit.""" boundary_tests = [ (0, True), # Empty (1, True), # Minimum @@ -184,69 +63,45 @@ class TestDescriptionValidationUnit: if should_pass: # Should not raise exception - assert app_validate(test_desc) == test_desc - assert dataset_validate(test_desc) == test_desc - assert service_dataset_validate(test_desc) == test_desc + assert validate_description_length(test_desc) == test_desc else: # Should raise ValueError with pytest.raises(ValueError): - app_validate(test_desc) - with pytest.raises(ValueError): - dataset_validate(test_desc) - with pytest.raises(ValueError): - service_dataset_validate(test_desc) + validate_description_length(test_desc) def test_special_characters(self): """Test validation with special characters, Unicode, etc.""" # Unicode characters unicode_desc = "测试描述" * 100 # Chinese characters if len(unicode_desc) <= 400: - assert app_validate(unicode_desc) == unicode_desc - assert dataset_validate(unicode_desc) == unicode_desc - assert service_dataset_validate(unicode_desc) == unicode_desc + assert validate_description_length(unicode_desc) == unicode_desc # Special characters special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" * 10 if len(special_desc) <= 400: - assert app_validate(special_desc) == special_desc - assert dataset_validate(special_desc) == special_desc - assert service_dataset_validate(special_desc) == special_desc + assert validate_description_length(special_desc) == special_desc # Mixed content mixed_desc = "Mixed content: 测试 123 !@# " * 15 if len(mixed_desc) <= 400: - assert app_validate(mixed_desc) == mixed_desc - assert dataset_validate(mixed_desc) == mixed_desc - assert service_dataset_validate(mixed_desc) == mixed_desc + assert validate_description_length(mixed_desc) == mixed_desc elif len(mixed_desc) > 400: with pytest.raises(ValueError): - app_validate(mixed_desc) - with pytest.raises(ValueError): - dataset_validate(mixed_desc) - with pytest.raises(ValueError): - service_dataset_validate(mixed_desc) + validate_description_length(mixed_desc) def test_whitespace_handling(self): - """Test validation with various whitespace scenarios""" + """Test validation with various whitespace scenarios.""" # Leading/trailing whitespace whitespace_desc = " Description with whitespace " if len(whitespace_desc) <= 400: - assert app_validate(whitespace_desc) == whitespace_desc - assert dataset_validate(whitespace_desc) == whitespace_desc - assert service_dataset_validate(whitespace_desc) == whitespace_desc + assert validate_description_length(whitespace_desc) == whitespace_desc # Newlines and tabs multiline_desc = "Line 1\nLine 2\tTabbed content" if len(multiline_desc) <= 400: - assert app_validate(multiline_desc) == multiline_desc - assert dataset_validate(multiline_desc) == multiline_desc - assert service_dataset_validate(multiline_desc) == multiline_desc + assert validate_description_length(multiline_desc) == multiline_desc # Only whitespace over limit only_spaces = " " * 401 with pytest.raises(ValueError): - app_validate(only_spaces) - with pytest.raises(ValueError): - dataset_validate(only_spaces) - with pytest.raises(ValueError): - service_dataset_validate(only_spaces) + validate_description_length(only_spaces) diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index a7bdf5de33..67f4b85413 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -143,7 +143,7 @@ class TestOAuthCallback: oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com") account = MagicMock() - account.status = AccountStatus.ACTIVE.value + account.status = AccountStatus.ACTIVE token_pair = MagicMock() token_pair.access_token = "jwt_access_token" @@ -201,9 +201,9 @@ class TestOAuthCallback: mock_db.session.rollback = MagicMock() # Import the real requests module to create a proper exception - import requests + import httpx - request_exception = requests.exceptions.RequestException("OAuth error") + request_exception = httpx.RequestError("OAuth error") request_exception.response = MagicMock() request_exception.response.text = str(exception) @@ -220,11 +220,11 @@ class TestOAuthCallback: @pytest.mark.parametrize( ("account_status", "expected_redirect"), [ - (AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."), + (AccountStatus.BANNED, "http://localhost:3000/signin?message=Account is banned."), # CLOSED status: Currently NOT handled, will proceed to login (security issue) # This documents actual behavior. See test_defensive_check_for_closed_account_status for details ( - AccountStatus.CLOSED.value, + AccountStatus.CLOSED, "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token", ), ], @@ -296,13 +296,13 @@ class TestOAuthCallback: mock_get_providers.return_value = {"github": oauth_setup["provider"]} mock_account = MagicMock() - mock_account.status = AccountStatus.PENDING.value + mock_account.status = AccountStatus.PENDING mock_generate_account.return_value = mock_account with app.test_request_context("/auth/oauth/github/callback?code=test_code"): resource.get("github") - assert mock_account.status == AccountStatus.ACTIVE.value + assert mock_account.status == AccountStatus.ACTIVE assert mock_account.initialized_at is not None mock_db.session.commit.assert_called_once() @@ -352,7 +352,7 @@ class TestOAuthCallback: # Create account with CLOSED status closed_account = MagicMock() - closed_account.status = AccountStatus.CLOSED.value + closed_account.status = AccountStatus.CLOSED closed_account.id = "123" closed_account.name = "Closed Account" mock_generate_account.return_value = closed_account diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 1dc380ad0b..2cbff54c42 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -329,20 +329,20 @@ class TestAliyunConfig: assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" + """Test endpoint validation preserves path for Aliyun endpoints""" config = AliyunConfig( license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" ) - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" def test_endpoint_validation_invalid_scheme(self): """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") def test_endpoint_validation_no_scheme(self): """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") def test_license_key_required(self): @@ -350,6 +350,23 @@ class TestAliyunConfig: with pytest.raises(ValidationError): AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + def test_valid_endpoint_format_examples(self): + """Test valid endpoint format examples from comments""" + valid_endpoints = [ + # cms2.0 public endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", + # cms2.0 intranet endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", + # xtrace public endpoint + "http://tracing-cn-heyuan.arms.aliyuncs.com", + # xtrace intranet endpoint + "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", + ] + + for endpoint in valid_endpoints: + config = AliyunConfig(license_key="test_license", endpoint=endpoint) + assert config.endpoint == endpoint + class TestConfigIntegration: """Integration tests for configuration classes""" @@ -382,7 +399,7 @@ class TestConfigIntegration: assert arize_config.endpoint == "https://arize.com" assert phoenix_with_path_config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" assert phoenix_without_path_config.endpoint == "https://app.phoenix.arize.com" - assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" + assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" def test_project_default_values(self): """Test that project default values are set correctly""" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py new file mode 100644 index 0000000000..44fe272c8c --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py @@ -0,0 +1,722 @@ +import json +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import ( + AlibabaCloudMySQLVector, + AlibabaCloudMySQLVectorConfig, +) +from core.rag.models.document import Document + +try: + from mysql.connector import Error as MySQLError +except ImportError: + # Fallback for testing environments where mysql-connector-python might not be installed + class MySQLError(Exception): + def __init__(self, errno, msg): + self.errno = errno + self.msg = msg + super().__init__(msg) + + +class TestAlibabaCloudMySQLVector(unittest.TestCase): + def setUp(self): + self.config = AlibabaCloudMySQLVectorConfig( + host="localhost", + port=3306, + user="test_user", + password="test_password", + database="test_db", + max_connection=5, + charset="utf8mb4", + ) + self.collection_name = "test_collection" + + # Sample documents for testing + self.sample_documents = [ + Document( + page_content="This is a test document about AI.", + metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"}, + ), + Document( + page_content="Another document about machine learning.", + metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"}, + ), + ] + + # Sample embeddings + self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]] + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_init(self, mock_pool_class): + """Test AlibabaCloudMySQLVector initialization.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor for vector support check + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + {"VERSION()": "8.0.36"}, # Version check + {"vector_support": True}, # Vector support check + ] + + alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config) + + assert alibabacloud_mysql_vector.collection_name == self.collection_name + assert alibabacloud_mysql_vector.table_name == self.collection_name.lower() + assert alibabacloud_mysql_vector.get_type() == "alibabacloud_mysql" + assert alibabacloud_mysql_vector.distance_function == "cosine" + assert alibabacloud_mysql_vector.pool is not None + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") + def test_create_collection(self, mock_redis, mock_pool_class): + """Test collection creation.""" + # Mock Redis operations + mock_redis.lock.return_value.__enter__ = MagicMock() + mock_redis.lock.return_value.__exit__ = MagicMock() + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + {"VERSION()": "8.0.36"}, # Version check + {"vector_support": True}, # Vector support check + ] + + alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config) + alibabacloud_mysql_vector._create_collection(768) + + # Verify SQL execution calls - should include table creation and index creation + assert mock_cursor.execute.called + assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes + mock_redis.set.assert_called_once() + + def test_config_validation(self): + """Test configuration validation.""" + # Test missing required fields + with pytest.raises(ValueError): + AlibabaCloudMySQLVectorConfig( + host="", # Empty host should raise error + port=3306, + user="test", + password="test", + database="test", + max_connection=5, + ) + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_vector_support_check_success(self, mock_pool_class): + """Test successful vector support check.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + # Should not raise an exception + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + assert vector_store is not None + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_vector_support_check_failure(self, mock_pool_class): + """Test vector support check failure.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.35"}, {"vector_support": False}] + + with pytest.raises(ValueError) as context: + AlibabaCloudMySQLVector(self.collection_name, self.config) + + assert "RDS MySQL Vector functions are not available" in str(context.value) + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_vector_support_check_function_error(self, mock_pool_class): + """Test vector support check with function not found error.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = {"VERSION()": "8.0.36"} + mock_cursor.execute.side_effect = [None, MySQLError(errno=1305, msg="FUNCTION VEC_FromText does not exist")] + + with pytest.raises(ValueError) as context: + AlibabaCloudMySQLVector(self.collection_name, self.config) + + assert "RDS MySQL Vector functions are not available" in str(context.value) + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") + def test_create_documents(self, mock_redis, mock_pool_class): + """Test creating documents with embeddings.""" + # Setup mocks + self._setup_mocks(mock_redis, mock_pool_class) + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + result = vector_store.create(self.sample_documents, self.sample_embeddings) + + assert len(result) == 2 + assert "doc1" in result + assert "doc2" in result + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_add_texts(self, mock_pool_class): + """Test adding texts to the vector store.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + result = vector_store.add_texts(self.sample_documents, self.sample_embeddings) + + assert len(result) == 2 + mock_cursor.executemany.assert_called_once() + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_text_exists(self, mock_pool_class): + """Test checking if text exists.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + {"VERSION()": "8.0.36"}, + {"vector_support": True}, + {"id": "doc1"}, # Text exists + ] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + exists = vector_store.text_exists("doc1") + + assert exists + # Check that the correct SQL was executed (last call after init) + execute_calls = mock_cursor.execute.call_args_list + last_call = execute_calls[-1] + assert "SELECT id FROM" in last_call[0][0] + assert last_call[0][1] == ("doc1",) + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_text_not_exists(self, mock_pool_class): + """Test checking if text does not exist.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [ + {"VERSION()": "8.0.36"}, + {"vector_support": True}, + None, # Text does not exist + ] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + exists = vector_store.text_exists("nonexistent") + + assert not exists + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_get_by_ids(self, mock_pool_class): + """Test getting documents by IDs.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + mock_cursor.__iter__ = lambda self: iter( + [ + {"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1"}, + {"meta": json.dumps({"doc_id": "doc2", "source": "test"}), "text": "Test document 2"}, + ] + ) + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + docs = vector_store.get_by_ids(["doc1", "doc2"]) + + assert len(docs) == 2 + assert docs[0].page_content == "Test document 1" + assert docs[1].page_content == "Test document 2" + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_get_by_ids_empty_list(self, mock_pool_class): + """Test getting documents with empty ID list.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + docs = vector_store.get_by_ids([]) + + assert len(docs) == 0 + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_delete_by_ids(self, mock_pool_class): + """Test deleting documents by IDs.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + vector_store.delete_by_ids(["doc1", "doc2"]) + + # Check that delete SQL was executed + execute_calls = mock_cursor.execute.call_args_list + delete_calls = [call for call in execute_calls if "DELETE" in str(call)] + assert len(delete_calls) == 1 + delete_call = delete_calls[0] + assert "DELETE FROM" in delete_call[0][0] + assert delete_call[0][1] == ["doc1", "doc2"] + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_delete_by_ids_empty_list(self, mock_pool_class): + """Test deleting with empty ID list.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + vector_store.delete_by_ids([]) # Should not raise an exception + + # Verify no delete SQL was executed + execute_calls = mock_cursor.execute.call_args_list + delete_calls = [call for call in execute_calls if "DELETE" in str(call)] + assert len(delete_calls) == 0 + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_delete_by_ids_table_not_exists(self, mock_pool_class): + """Test deleting when table doesn't exist.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + # Simulate table doesn't exist error on delete + + def execute_side_effect(*args, **kwargs): + if "DELETE" in args[0]: + raise MySQLError(errno=1146, msg="Table doesn't exist") + + mock_cursor.execute.side_effect = execute_side_effect + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + # Should not raise an exception + vector_store.delete_by_ids(["doc1"]) + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_delete_by_metadata_field(self, mock_pool_class): + """Test deleting documents by metadata field.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + vector_store.delete_by_metadata_field("document_id", "dataset1") + + # Check that the correct SQL was executed + execute_calls = mock_cursor.execute.call_args_list + delete_calls = [call for call in execute_calls if "DELETE" in str(call)] + assert len(delete_calls) == 1 + delete_call = delete_calls[0] + assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0] + assert delete_call[0][1] == ("$.document_id", "dataset1") + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_search_by_vector_cosine(self, mock_pool_class): + """Test vector search with cosine distance.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + mock_cursor.__iter__ = lambda self: iter( + [{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 0.1}] + ) + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + query_vector = [0.1, 0.2, 0.3, 0.4] + docs = vector_store.search_by_vector(query_vector, top_k=5) + + assert len(docs) == 1 + assert docs[0].page_content == "Test document 1" + assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9 + assert docs[0].metadata["distance"] == 0.1 + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_search_by_vector_euclidean(self, mock_pool_class): + """Test vector search with euclidean distance.""" + config = AlibabaCloudMySQLVectorConfig( + host="localhost", + port=3306, + user="test_user", + password="test_password", + database="test_db", + max_connection=5, + distance_function="euclidean", + ) + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + mock_cursor.__iter__ = lambda self: iter( + [{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 2.0}] + ) + + vector_store = AlibabaCloudMySQLVector(self.collection_name, config) + query_vector = [0.1, 0.2, 0.3, 0.4] + docs = vector_store.search_by_vector(query_vector, top_k=5) + + assert len(docs) == 1 + assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3 + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_search_by_vector_with_filter(self, mock_pool_class): + """Test vector search with document ID filter.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + mock_cursor.__iter__ = lambda self: iter([]) + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + query_vector = [0.1, 0.2, 0.3, 0.4] + docs = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["dataset1"]) + + # Verify the SQL contains the WHERE clause for filtering + execute_calls = mock_cursor.execute.call_args_list + search_calls = [call for call in execute_calls if "VEC_DISTANCE" in str(call)] + assert len(search_calls) > 0 + search_call = search_calls[0] + assert "WHERE JSON_UNQUOTE" in search_call[0][0] + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_search_by_vector_with_score_threshold(self, mock_pool_class): + """Test vector search with score threshold.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + mock_cursor.__iter__ = lambda self: iter( + [ + { + "meta": json.dumps({"doc_id": "doc1", "source": "test"}), + "text": "High similarity document", + "distance": 0.1, # High similarity (score = 0.9) + }, + { + "meta": json.dumps({"doc_id": "doc2", "source": "test"}), + "text": "Low similarity document", + "distance": 0.8, # Low similarity (score = 0.2) + }, + ] + ) + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + query_vector = [0.1, 0.2, 0.3, 0.4] + docs = vector_store.search_by_vector(query_vector, top_k=5, score_threshold=0.5) + + # Only the high similarity document should be returned + assert len(docs) == 1 + assert docs[0].page_content == "High similarity document" + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_search_by_vector_invalid_top_k(self, mock_pool_class): + """Test vector search with invalid top_k.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + query_vector = [0.1, 0.2, 0.3, 0.4] + + with pytest.raises(ValueError): + vector_store.search_by_vector(query_vector, top_k=0) + + with pytest.raises(ValueError): + vector_store.search_by_vector(query_vector, top_k="invalid") + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_search_by_full_text(self, mock_pool_class): + """Test full-text search.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + mock_cursor.__iter__ = lambda self: iter( + [ + { + "meta": {"doc_id": "doc1", "source": "test"}, + "text": "This document contains machine learning content", + "score": 1.5, + } + ] + ) + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + docs = vector_store.search_by_full_text("machine learning", top_k=5) + + assert len(docs) == 1 + assert docs[0].page_content == "This document contains machine learning content" + assert docs[0].metadata["score"] == 1.5 + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_search_by_full_text_with_filter(self, mock_pool_class): + """Test full-text search with document ID filter.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + mock_cursor.__iter__ = lambda self: iter([]) + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + docs = vector_store.search_by_full_text("machine learning", top_k=5, document_ids_filter=["dataset1"]) + + # Verify the SQL contains the AND clause for filtering + execute_calls = mock_cursor.execute.call_args_list + search_calls = [call for call in execute_calls if "MATCH" in str(call)] + assert len(search_calls) > 0 + search_call = search_calls[0] + assert "AND JSON_UNQUOTE" in search_call[0][0] + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_search_by_full_text_invalid_top_k(self, mock_pool_class): + """Test full-text search with invalid top_k.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + + with pytest.raises(ValueError): + vector_store.search_by_full_text("test", top_k=0) + + with pytest.raises(ValueError): + vector_store.search_by_full_text("test", top_k="invalid") + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_delete_collection(self, mock_pool_class): + """Test deleting the entire collection.""" + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) + vector_store.delete() + + # Check that DROP TABLE SQL was executed + execute_calls = mock_cursor.execute.call_args_list + drop_calls = [call for call in execute_calls if "DROP TABLE" in str(call)] + assert len(drop_calls) == 1 + drop_call = drop_calls[0] + assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0] + + @patch( + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" + ) + def test_unsupported_distance_function(self, mock_pool_class): + """Test that Pydantic validation rejects unsupported distance functions.""" + # Test that creating config with unsupported distance function raises ValidationError + with pytest.raises(ValueError) as context: + AlibabaCloudMySQLVectorConfig( + host="localhost", + port=3306, + user="test_user", + password="test_password", + database="test_db", + max_connection=5, + distance_function="manhattan", # Unsupported - not in Literal["cosine", "euclidean"] + ) + + # The error should be related to validation + assert "Input should be 'cosine' or 'euclidean'" in str(context.value) or "manhattan" in str(context.value) + + def _setup_mocks(self, mock_redis, mock_pool_class): + """Helper method to setup common mocks.""" + # Mock Redis operations + mock_redis.lock.return_value.__enter__ = MagicMock() + mock_redis.lock.return_value.__exit__ = MagicMock() + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}] + + +if __name__ == "__main__": + unittest.main() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index 48cc8a7e1c..fb2ddfe162 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -11,8 +11,8 @@ def test_default_value(): config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: - MilvusConfig(**config) + MilvusConfig.model_validate(config) assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" - config = MilvusConfig(**valid_config) + config = MilvusConfig.model_validate(valid_config) assert config.database == "default" diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 6689e13b96..e5ead6ff66 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -18,7 +18,7 @@ def test_firecrawl_web_extractor_crawl_mode(mocker): mocked_firecrawl = { "id": "test", } - mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl)) + mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl)) job_id = firecrawl_app.crawl_url(url, params) assert job_id is not None diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index eea584a2f8..f1e1820acc 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -69,7 +69,7 @@ def test_notion_page(mocker): ], "next_cursor": None, } - mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) + mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page)) page_docs = extractor._load_data_as_documents(page_id, "page") assert len(page_docs) == 1 @@ -84,7 +84,7 @@ def test_notion_database(mocker): "results": [_generate_page(i) for i in page_title_list], "next_cursor": None, } - mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) + mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database)) database_docs = extractor._load_data_as_documents(database_id, "database") assert len(database_docs) == 1 content = _remove_multiple_new_lines(database_docs[0].page_content) diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e7733b2317..e6d0371cd5 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -140,7 +140,7 @@ class TestCeleryWorkflowExecutionRepository: assert call_args["execution_data"] == sample_workflow_execution.model_dump() assert call_args["tenant_id"] == mock_account.current_tenant_id assert call_args["app_id"] == "test-app" - assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value + assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN assert call_args["creator_user_id"] == mock_account.id # Verify no task tracking occurs (no _pending_saves attribute) diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 3abe20fca1..f6211f4cca 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -149,7 +149,7 @@ class TestCeleryWorkflowNodeExecutionRepository: assert call_args["execution_data"] == sample_workflow_node_execution.model_dump() assert call_args["tenant_id"] == mock_account.current_tenant_id assert call_args["app_id"] == "test-app" - assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN assert call_args["creator_user_id"] == mock_account.id # Verify execution is cached diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 36f7d3ef55..485be90eae 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -145,12 +145,12 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation: db_model.index = 1 db_model.predecessor_node_id = None db_model.node_id = "node-id" - db_model.node_type = NodeType.LLM.value + db_model.node_type = NodeType.LLM db_model.title = "Test Node" db_model.inputs = json.dumps({"value": "inputs"}) db_model.process_data = json.dumps({"value": "process_data"}) db_model.outputs = json.dumps({"value": "outputs"}) - db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED db_model.error = None db_model.elapsed_time = 1.0 db_model.execution_metadata = "{}" diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py new file mode 100644 index 0000000000..68fe82d05e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -0,0 +1,113 @@ +from core.variables.segments import ( + BooleanSegment, + IntegerSegment, + NoneSegment, + StringSegment, +) +from core.workflow.entities.variable_pool import VariablePool + + +class TestVariablePoolGetAndNestedAttribute: + # + # _get_nested_attribute tests + # + def test__get_nested_attribute_existing_key(self): + pool = VariablePool.empty() + obj = {"a": 123} + segment = pool._get_nested_attribute(obj, "a") + assert segment is not None + assert segment.value == 123 + + def test__get_nested_attribute_missing_key(self): + pool = VariablePool.empty() + obj = {"a": 123} + segment = pool._get_nested_attribute(obj, "b") + assert segment is None + + def test__get_nested_attribute_non_dict(self): + pool = VariablePool.empty() + obj = ["not", "a", "dict"] + segment = pool._get_nested_attribute(obj, "a") + assert segment is None + + def test__get_nested_attribute_with_none_value(self): + pool = VariablePool.empty() + obj = {"a": None} + segment = pool._get_nested_attribute(obj, "a") + assert segment is not None + assert isinstance(segment, NoneSegment) + + def test__get_nested_attribute_with_empty_string(self): + pool = VariablePool.empty() + obj = {"a": ""} + segment = pool._get_nested_attribute(obj, "a") + assert segment is not None + assert isinstance(segment, StringSegment) + assert segment.value == "" + + # + # get tests + # + def test_get_simple_variable(self): + pool = VariablePool.empty() + pool.add(("node1", "var1"), "value1") + segment = pool.get(("node1", "var1")) + assert segment is not None + assert segment.value == "value1" + + def test_get_missing_variable(self): + pool = VariablePool.empty() + result = pool.get(("node1", "unknown")) + assert result is None + + def test_get_with_too_short_selector(self): + pool = VariablePool.empty() + result = pool.get(("only_node",)) + assert result is None + + def test_get_nested_object_attribute(self): + pool = VariablePool.empty() + obj_value = {"inner": "hello"} + pool.add(("node1", "obj"), obj_value) + + # simulate selector with nested attr + segment = pool.get(("node1", "obj", "inner")) + assert segment is not None + assert segment.value == "hello" + + def test_get_nested_object_missing_attribute(self): + pool = VariablePool.empty() + obj_value = {"inner": "hello"} + pool.add(("node1", "obj"), obj_value) + + result = pool.get(("node1", "obj", "not_exist")) + assert result is None + + def test_get_nested_object_attribute_with_falsy_values(self): + pool = VariablePool.empty() + obj_value = { + "inner_none": None, + "inner_empty": "", + "inner_zero": 0, + "inner_false": False, + } + pool.add(("node1", "obj"), obj_value) + + segment_none = pool.get(("node1", "obj", "inner_none")) + assert segment_none is not None + assert isinstance(segment_none, NoneSegment) + + segment_empty = pool.get(("node1", "obj", "inner_empty")) + assert segment_empty is not None + assert isinstance(segment_empty, StringSegment) + assert segment_empty.value == "" + + segment_zero = pool.get(("node1", "obj", "inner_zero")) + assert segment_zero is not None + assert isinstance(segment_zero, IntegerSegment) + assert segment_zero.value == 0 + + segment_false = pool.get(("node1", "obj", "inner_false")) + assert segment_false is not None + assert isinstance(segment_false, BooleanSegment) + assert segment_false.value is False diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 2c08fff27b..7ebccf83a7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -147,7 +147,7 @@ class TestRedisChannel: """Test deserializing an abort command.""" channel = RedisChannel(MagicMock(), "test:key") - abort_data = {"command_type": CommandType.ABORT.value} + abort_data = {"command_type": CommandType.ABORT} command = channel._deserialize_command(abort_data) assert isinstance(command, AbortCommand) @@ -158,7 +158,7 @@ class TestRedisChannel: channel = RedisChannel(MagicMock(), "test:key") # For now, only ABORT is supported, but test generic handling - generic_data = {"command_type": CommandType.ABORT.value} + generic_data = {"command_type": CommandType.ABORT} command = channel._deserialize_command(generic_data) assert command is not None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py new file mode 100644 index 0000000000..d556bb138e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -0,0 +1,120 @@ +"""Tests for graph engine event handlers.""" + +from __future__ import annotations + +from datetime import datetime + +from core.workflow.entities import GraphRuntimeState, VariablePool +from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph import Graph +from core.workflow.graph_engine.domain.graph_execution import GraphExecution +from core.workflow.graph_engine.event_management.event_handlers import EventHandler +from core.workflow.graph_engine.event_management.event_manager import EventManager +from core.workflow.graph_engine.graph_state_manager import GraphStateManager +from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue +from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent +from core.workflow.node_events import NodeRunResult +from core.workflow.nodes.base.entities import RetryConfig + + +class _StubEdgeProcessor: + """Minimal edge processor stub for tests.""" + + +class _StubErrorHandler: + """Minimal error handler stub for tests.""" + + +class _StubNode: + """Simple node stub exposing the attributes needed by the state manager.""" + + def __init__(self, node_id: str) -> None: + self.id = node_id + self.state = NodeState.UNKNOWN + self.title = "Stub Node" + self.execution_type = NodeExecutionType.EXECUTABLE + self.error_strategy = None + self.retry_config = RetryConfig() + self.retry = False + + +def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]: + """Construct an EventHandler with in-memory dependencies for testing.""" + + node = _StubNode(node_id) + graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node) + + variable_pool = VariablePool() + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_execution = GraphExecution(workflow_id="test-workflow") + + event_manager = EventManager() + state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue()) + response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph) + + handler = EventHandler( + graph=graph, + graph_runtime_state=runtime_state, + graph_execution=graph_execution, + response_coordinator=response_coordinator, + event_collector=event_manager, + edge_processor=_StubEdgeProcessor(), + state_manager=state_manager, + error_handler=_StubErrorHandler(), + ) + + return handler, event_manager, graph_execution + + +def test_retry_does_not_emit_additional_start_event() -> None: + """Ensure retry attempts do not produce duplicate start events.""" + + node_id = "test-node" + handler, event_manager, graph_execution = _build_event_handler(node_id) + + execution_id = "exec-1" + node_type = NodeType.CODE + start_time = datetime.utcnow() + + start_event = NodeRunStartedEvent( + id=execution_id, + node_id=node_id, + node_type=node_type, + node_title="Stub Node", + start_at=start_time, + ) + handler.dispatch(start_event) + + retry_event = NodeRunRetryEvent( + id=execution_id, + node_id=node_id, + node_type=node_type, + node_title="Stub Node", + start_at=start_time, + error="boom", + retry_index=1, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="boom", + error_type="TestError", + ), + ) + handler.dispatch(retry_event) + + # Simulate the node starting execution again after retry + second_start_event = NodeRunStartedEvent( + id=execution_id, + node_id=node_id, + node_type=node_type, + node_title="Stub Node", + start_at=start_time, + ) + handler.dispatch(second_start_event) + + collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined] + + assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent] + + node_execution = graph_execution.get_or_create_node_execution(node_id) + assert node_execution.retry_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py new file mode 100644 index 0000000000..6569439b56 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py @@ -0,0 +1,28 @@ +from .test_mock_config import MockConfigBuilder +from .test_table_runner import TableTestRunner, WorkflowTestCase + +LLM_NODE_ID = "1759052580454" + + +def test_answer_nodes_emit_in_order() -> None: + mock_config = ( + MockConfigBuilder() + .with_llm_response("unused default") + .with_node_output(LLM_NODE_ID, {"text": "mocked llm text"}) + .build() + ) + + expected_answer = "--- answer 1 ---\n\nfoo\n--- answer 2 ---\n\nmocked llm text\n" + + case = WorkflowTestCase( + fixture_path="test-answer-order", + query="", + expected_outputs={"answer": expected_answer}, + use_auto_mock=True, + mock_config=mock_config, + ) + + runner = TableTestRunner() + result = runner.run_test_case(case) + + assert result.success, result.error diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 6a723999de..4a117f8c96 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -10,11 +10,18 @@ import time from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st +from core.workflow.enums import ErrorStrategy from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent +from core.workflow.graph_events import ( + GraphRunPartialSucceededEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType # Import the test framework from the new module +from .test_mock_config import MockConfigBuilder from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase @@ -721,3 +728,39 @@ def test_event_sequence_validation_with_table_tests(): else: assert result.event_sequence_match is True assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" + + +def test_graph_run_emits_partial_success_when_node_failure_recovered(): + runner = TableTestRunner() + + fixture_data = runner.workflow_runner.load_fixture("basic_chatflow") + mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build() + + graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( + fixture_data=fixture_data, + query="hello", + use_mock_factory=True, + mock_config=mock_config, + ) + + llm_node = graph.nodes["llm"] + base_node_data = llm_node.get_base_node_data() + base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE + base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)] + + engine = GraphEngine( + workflow_id="test_workflow", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + ) + + events = list(engine.run()) + + assert isinstance(events[-1], GraphRunPartialSucceededEvent) + + partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent)) + assert partial_event.exceptions_count == 1 + assert partial_event.outputs.get("answer") == "fallback response" + + assert not any(isinstance(event, GraphRunSucceededEvent) for event in events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 6a9bfbdcc3..c39c12925f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -56,8 +56,8 @@ def test_mock_iteration_node_preserves_config(): workflow_id="test", graph_config={"nodes": [], "edges": []}, user_id="test", - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) @@ -117,8 +117,8 @@ def test_mock_loop_node_preserves_config(): workflow_id="test", graph_config={"nodes": [], "edges": []}, user_id="test", - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index b286d99f70..bd41fdeee5 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -49,7 +49,7 @@ class TestRedisStopIntegration: # Verify the command data command_json = calls[0][0][1] command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.ABORT.value + assert command_data["command_type"] == CommandType.ABORT assert command_data["reason"] == "Test stop" def test_graph_engine_manager_handles_redis_failure_gracefully(self): @@ -122,7 +122,7 @@ class TestRedisStopIntegration: # Verify serialized command command_json = calls[0][0][1] command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.ABORT.value + assert command_data["command_type"] == CommandType.ABORT assert command_data["reason"] == "User requested stop" # Check expire was set @@ -137,9 +137,7 @@ class TestRedisStopIntegration: mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) # Mock command data - abort_command_json = json.dumps( - {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None} - ) + abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None}) # Mock pipeline execute to return commands mock_pipeline.execute.return_value = [ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py new file mode 100644 index 0000000000..a7309f64de --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py @@ -0,0 +1,41 @@ +"""Validate conversation variable updates inside an iteration workflow. + +This test uses the ``update-conversation-variable-in-iteration`` fixture, which +routes ``sys.query`` into the conversation variable ``answer`` from within an +iteration container. The workflow should surface that updated conversation +variable in the final answer output. + +Code nodes in the fixture are mocked because their concrete outputs are not +relevant to verifying variable propagation semantics. +""" + +from .test_mock_config import MockConfigBuilder +from .test_table_runner import TableTestRunner, WorkflowTestCase + + +def test_update_conversation_variable_in_iteration(): + fixture_name = "update-conversation-variable-in-iteration" + user_query = "ensure conversation variable syncs" + + mock_config = ( + MockConfigBuilder() + .with_node_output("1759032363865", {"result": [1]}) + .with_node_output("1759032476318", {"result": ""}) + .build() + ) + + case = WorkflowTestCase( + fixture_path=fixture_name, + use_auto_mock=True, + mock_config=mock_config, + query=user_query, + expected_outputs={"answer": user_query}, + description="Conversation variable updated within iteration should flow to answer output.", + ) + + runner = TableTestRunner() + result = runner.run_test_case(case) + + assert result.success, f"Workflow execution failed: {result.error}" + assert result.actual_outputs is not None + assert result.actual_outputs.get("answer") == user_query diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index b942614232..55fe62ca43 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -35,7 +35,7 @@ def list_operator_node(): "extract_by": ExtractConfig(enabled=False, serial="1"), "title": "Test Title", } - node_data = ListOperatorNodeData(**config) + node_data = ListOperatorNodeData.model_validate(config) node_config = { "id": "test_node_id", "data": node_data.model_dump(), diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py index f990280c5f..47ef289ef3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -17,7 +17,7 @@ def test_init_question_classifier_node_data(): "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, } - node_data = QuestionClassifierNodeData(**data) + node_data = QuestionClassifierNodeData.model_validate(data) assert node_data.query_variable_selector == ["id", "name"] assert node_data.model.provider == "openai" @@ -49,7 +49,7 @@ def test_init_question_classifier_node_data_without_vision_config(): }, } - node_data = QuestionClassifierNodeData(**data) + node_data = QuestionClassifierNodeData.model_validate(data) assert node_data.query_variable_selector == ["id", "name"] assert node_data.model.provider == "openai" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_retry.py b/api/tests/unit_tests/core/workflow/nodes/test_retry.py deleted file mode 100644 index 23cef58d2e..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_retry.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest - -pytest.skip( - "Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine", - allow_module_level=True, -) - -DEFAULT_VALUE_EDGE = [ - { - "id": "start-source-node-target", - "source": "start", - "target": "node", - "sourceHandle": "source", - }, - { - "id": "node-source-answer-target", - "source": "node", - "target": "answer", - "sourceHandle": "source", - }, -] - - -def test_retry_default_value_partial_success(): - """retry default value node with partial success status""" - graph_config = { - "edges": DEFAULT_VALUE_EDGE, - "nodes": [ - {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, - {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, - ContinueOnErrorTestHelper.get_http_node( - "default-value", - [{"key": "result", "type": "string", "value": "http node got error response"}], - retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, - ), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 - assert events[-1].outputs == {"answer": "http node got error response"} - assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events) - assert len(events) == 11 - - -def test_retry_failed(): - """retry failed with success status""" - graph_config = { - "edges": DEFAULT_VALUE_EDGE, - "nodes": [ - {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, - {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, - ContinueOnErrorTestHelper.get_http_node( - None, - None, - retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, - ), - ], - } - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 - assert any(isinstance(e, GraphRunFailedEvent) for e in events) - assert len(events) == 8 diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 3e50d5522a..6189febdf5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -87,7 +87,7 @@ def test_overwrite_string_variable(): "data": { "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE.value, + "write_mode": WriteMode.OVER_WRITE, "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, } @@ -189,7 +189,7 @@ def test_append_variable_to_array(): "data": { "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND.value, + "write_mode": WriteMode.APPEND, "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, } @@ -282,7 +282,7 @@ def test_clear_array(): "data": { "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR.value, + "write_mode": WriteMode.CLEAR, "input_variable_selector": [], }, } diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 11d788ed79..3ae5edb383 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -46,7 +46,7 @@ class TestSystemVariableSerialization: def test_basic_deserialization(self): """Test successful deserialization from JSON structure with all fields correctly mapped.""" # Test with complete data - system_var = SystemVariable(**COMPLETE_VALID_DATA) + system_var = SystemVariable.model_validate(COMPLETE_VALID_DATA) # Verify all fields are correctly mapped assert system_var.user_id == COMPLETE_VALID_DATA["user_id"] @@ -59,7 +59,7 @@ class TestSystemVariableSerialization: assert system_var.files == [] # Test with minimal data (only required fields) - minimal_var = SystemVariable(**VALID_BASE_DATA) + minimal_var = SystemVariable.model_validate(VALID_BASE_DATA) assert minimal_var.user_id == VALID_BASE_DATA["user_id"] assert minimal_var.app_id == VALID_BASE_DATA["app_id"] assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"] @@ -75,12 +75,12 @@ class TestSystemVariableSerialization: # Test workflow_run_id only (preferred alias) data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} - system_var1 = SystemVariable(**data_run_id) + system_var1 = SystemVariable.model_validate(data_run_id) assert system_var1.workflow_execution_id == workflow_id # Test workflow_execution_id only (direct field name) data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} - system_var2 = SystemVariable(**data_execution_id) + system_var2 = SystemVariable.model_validate(data_execution_id) assert system_var2.workflow_execution_id == workflow_id # Test both present - workflow_run_id should take precedence @@ -89,17 +89,17 @@ class TestSystemVariableSerialization: "workflow_execution_id": "should-be-ignored", "workflow_run_id": workflow_id, } - system_var3 = SystemVariable(**data_both) + system_var3 = SystemVariable.model_validate(data_both) assert system_var3.workflow_execution_id == workflow_id # Test neither present - should be None - system_var4 = SystemVariable(**VALID_BASE_DATA) + system_var4 = SystemVariable.model_validate(VALID_BASE_DATA) assert system_var4.workflow_execution_id is None def test_serialization_round_trip(self): """Test that serialize → deserialize produces the same result with alias handling.""" # Create original SystemVariable - original = SystemVariable(**COMPLETE_VALID_DATA) + original = SystemVariable.model_validate(COMPLETE_VALID_DATA) # Serialize to dict serialized = original.model_dump(mode="json") @@ -110,7 +110,7 @@ class TestSystemVariableSerialization: assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] # Deserialize back - deserialized = SystemVariable(**serialized) + deserialized = SystemVariable.model_validate(serialized) # Verify all fields match after round-trip assert deserialized.user_id == original.user_id @@ -125,7 +125,7 @@ class TestSystemVariableSerialization: def test_json_round_trip(self): """Test JSON serialization/deserialization consistency with proper structure.""" # Create original SystemVariable - original = SystemVariable(**COMPLETE_VALID_DATA) + original = SystemVariable.model_validate(COMPLETE_VALID_DATA) # Serialize to JSON string json_str = original.model_dump_json() @@ -137,7 +137,7 @@ class TestSystemVariableSerialization: assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] # Deserialize from JSON data - deserialized = SystemVariable(**json_data) + deserialized = SystemVariable.model_validate(json_data) # Verify key fields match after JSON round-trip assert deserialized.workflow_execution_id == original.workflow_execution_id @@ -149,13 +149,13 @@ class TestSystemVariableSerialization: """Test deserialization with File objects in the files field - SystemVariable specific logic.""" # Test with empty files list data_empty = {**VALID_BASE_DATA, "files": []} - system_var_empty = SystemVariable(**data_empty) + system_var_empty = SystemVariable.model_validate(data_empty) assert system_var_empty.files == [] # Test with single File object test_file = create_test_file() data_single = {**VALID_BASE_DATA, "files": [test_file]} - system_var_single = SystemVariable(**data_single) + system_var_single = SystemVariable.model_validate(data_single) assert len(system_var_single.files) == 1 assert system_var_single.files[0].filename == "test.txt" assert system_var_single.files[0].tenant_id == "test-tenant-id" @@ -179,14 +179,14 @@ class TestSystemVariableSerialization: ) data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]} - system_var_multiple = SystemVariable(**data_multiple) + system_var_multiple = SystemVariable.model_validate(data_multiple) assert len(system_var_multiple.files) == 2 assert system_var_multiple.files[0].filename == "doc1.txt" assert system_var_multiple.files[1].filename == "image.jpg" # Verify files field serialization/deserialization serialized = system_var_multiple.model_dump(mode="json") - deserialized = SystemVariable(**serialized) + deserialized = SystemVariable.model_validate(serialized) assert len(deserialized.files) == 2 assert deserialized.files[0].filename == "doc1.txt" assert deserialized.files[1].filename == "image.jpg" @@ -197,7 +197,7 @@ class TestSystemVariableSerialization: # Create with workflow_run_id (alias) data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} - system_var = SystemVariable(**data_with_alias) + system_var = SystemVariable.model_validate(data_with_alias) # Serialize and verify alias is used serialized = system_var.model_dump() @@ -205,7 +205,7 @@ class TestSystemVariableSerialization: assert "workflow_execution_id" not in serialized # Deserialize and verify field mapping - deserialized = SystemVariable(**serialized) + deserialized = SystemVariable.model_validate(serialized) assert deserialized.workflow_execution_id == workflow_id # Test JSON serialization path @@ -213,7 +213,7 @@ class TestSystemVariableSerialization: assert json_serialized["workflow_run_id"] == workflow_id assert "workflow_execution_id" not in json_serialized - json_deserialized = SystemVariable(**json_serialized) + json_deserialized = SystemVariable.model_validate(json_serialized) assert json_deserialized.workflow_execution_id == workflow_id def test_model_validator_serialization_logic(self): @@ -222,7 +222,7 @@ class TestSystemVariableSerialization: # Test direct instantiation with workflow_execution_id (should work) data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} - system_var1 = SystemVariable(**data1) + system_var1 = SystemVariable.model_validate(data1) assert system_var1.workflow_execution_id == workflow_id # Test serialization of the above (should use alias) @@ -236,7 +236,7 @@ class TestSystemVariableSerialization: "workflow_execution_id": "should-be-removed", "workflow_run_id": workflow_id, } - system_var2 = SystemVariable(**data2) + system_var2 = SystemVariable.model_validate(data2) assert system_var2.workflow_execution_id == workflow_id # Verify serialization consistency diff --git a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py index 958072223e..476f87269c 100644 --- a/api/tests/unit_tests/extensions/storage/test_supabase_storage.py +++ b/api/tests/unit_tests/extensions/storage/test_supabase_storage.py @@ -172,73 +172,31 @@ class TestSupabaseStorage: assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] mock_client.storage.from_().download.assert_called_with("test.txt") - def test_exists_with_list_containing_items(self, storage_with_mock_client): - """Test exists returns True when list() returns items (using len() > 0).""" + def test_exists_returns_true_when_file_found(self, storage_with_mock_client): + """Test exists returns True when list() returns items.""" storage, mock_client = storage_with_mock_client - # Mock list return with special object that has count() method - mock_list_result = Mock() - mock_list_result.count.return_value = 1 - mock_client.storage.from_().list.return_value = mock_list_result + mock_client.storage.from_().list.return_value = [{"name": "test.txt"}] result = storage.exists("test.txt") assert result is True - # from_ gets called during init too, so just check it was called with the right bucket assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] - mock_client.storage.from_().list.assert_called_with("test.txt") + mock_client.storage.from_().list.assert_called_with(path="test.txt") - def test_exists_with_count_method_greater_than_zero(self, storage_with_mock_client): - """Test exists returns True when list result has count() > 0.""" + def test_exists_returns_false_when_file_not_found(self, storage_with_mock_client): + """Test exists returns False when list() returns an empty list.""" storage, mock_client = storage_with_mock_client - # Mock list return with count() method - mock_list_result = Mock() - mock_list_result.count.return_value = 1 - mock_client.storage.from_().list.return_value = mock_list_result - - result = storage.exists("test.txt") - - assert result is True - # Verify the correct calls were made - assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] - mock_client.storage.from_().list.assert_called_with("test.txt") - mock_list_result.count.assert_called() - - def test_exists_with_count_method_zero(self, storage_with_mock_client): - """Test exists returns False when list result has count() == 0.""" - storage, mock_client = storage_with_mock_client - - # Mock list return with count() method returning 0 - mock_list_result = Mock() - mock_list_result.count.return_value = 0 - mock_client.storage.from_().list.return_value = mock_list_result + mock_client.storage.from_().list.return_value = [] result = storage.exists("test.txt") assert result is False - # Verify the correct calls were made assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] - mock_client.storage.from_().list.assert_called_with("test.txt") - mock_list_result.count.assert_called() + mock_client.storage.from_().list.assert_called_with(path="test.txt") - def test_exists_with_empty_list(self, storage_with_mock_client): - """Test exists returns False when list() returns empty list.""" - storage, mock_client = storage_with_mock_client - - # Mock list return with special object that has count() method returning 0 - mock_list_result = Mock() - mock_list_result.count.return_value = 0 - mock_client.storage.from_().list.return_value = mock_list_result - - result = storage.exists("test.txt") - - assert result is False - # Verify the correct calls were made - assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]] - mock_client.storage.from_().list.assert_called_with("test.txt") - - def test_delete_calls_remove_with_filename(self, storage_with_mock_client): + def test_delete_calls_remove_with_filename_in_list(self, storage_with_mock_client): """Test delete calls remove([...]) (some client versions require a list).""" storage, mock_client = storage_with_mock_client @@ -247,7 +205,7 @@ class TestSupabaseStorage: storage.delete(filename) mock_client.storage.from_.assert_called_once_with("test-bucket") - mock_client.storage.from_().remove.assert_called_once_with(filename) + mock_client.storage.from_().remove.assert_called_once_with([filename]) def test_bucket_exists_returns_true_when_bucket_found(self): """Test bucket_exists returns True when bucket is found in list.""" diff --git a/api/tests/unit_tests/factories/test_file_factory.py b/api/tests/unit_tests/factories/test_file_factory.py new file mode 100644 index 0000000000..777fe5a6e7 --- /dev/null +++ b/api/tests/unit_tests/factories/test_file_factory.py @@ -0,0 +1,115 @@ +import re + +import pytest + +from factories.file_factory import _get_remote_file_info + + +class _FakeResponse: + def __init__(self, status_code: int, headers: dict[str, str]): + self.status_code = status_code + self.headers = headers + + +def _mock_head(monkeypatch: pytest.MonkeyPatch, headers: dict[str, str], status_code: int = 200): + def _fake_head(url: str, follow_redirects: bool = True): + return _FakeResponse(status_code=status_code, headers=headers) + + monkeypatch.setattr("factories.file_factory.ssrf_proxy.head", _fake_head) + + +class TestGetRemoteFileInfo: + """Tests for _get_remote_file_info focusing on filename extraction rules.""" + + def test_inline_no_filename(self, monkeypatch: pytest.MonkeyPatch): + _mock_head( + monkeypatch, + { + "Content-Disposition": "inline", + "Content-Type": "application/pdf", + "Content-Length": "123", + }, + ) + mime_type, filename, size = _get_remote_file_info("http://example.com/some/path/file.pdf") + assert filename == "file.pdf" + assert mime_type == "application/pdf" + assert size == 123 + + def test_attachment_no_filename(self, monkeypatch: pytest.MonkeyPatch): + _mock_head( + monkeypatch, + { + "Content-Disposition": "attachment", + "Content-Type": "application/octet-stream", + "Content-Length": "456", + }, + ) + mime_type, filename, size = _get_remote_file_info("http://example.com/downloads/data.bin") + assert filename == "data.bin" + assert mime_type == "application/octet-stream" + assert size == 456 + + def test_attachment_quoted_space_filename(self, monkeypatch: pytest.MonkeyPatch): + _mock_head( + monkeypatch, + { + "Content-Disposition": 'attachment; filename="file name.jpg"', + "Content-Type": "image/jpeg", + "Content-Length": "789", + }, + ) + mime_type, filename, size = _get_remote_file_info("http://example.com/ignored") + assert filename == "file name.jpg" + assert mime_type == "image/jpeg" + assert size == 789 + + def test_attachment_filename_star_percent20(self, monkeypatch: pytest.MonkeyPatch): + _mock_head( + monkeypatch, + { + "Content-Disposition": "attachment; filename*=UTF-8''file%20name.jpg", + "Content-Type": "image/jpeg", + }, + ) + mime_type, filename, _ = _get_remote_file_info("http://example.com/ignored") + assert filename == "file name.jpg" + assert mime_type == "image/jpeg" + + def test_attachment_filename_star_chinese(self, monkeypatch: pytest.MonkeyPatch): + _mock_head( + monkeypatch, + { + "Content-Disposition": "attachment; filename*=UTF-8''%E6%B5%8B%E8%AF%95%E6%96%87%E4%BB%B6.jpg", + "Content-Type": "image/jpeg", + }, + ) + mime_type, filename, _ = _get_remote_file_info("http://example.com/ignored") + assert filename == "测试文件.jpg" + assert mime_type == "image/jpeg" + + def test_filename_from_url_when_no_header(self, monkeypatch: pytest.MonkeyPatch): + _mock_head( + monkeypatch, + { + # No Content-Disposition + "Content-Type": "text/plain", + "Content-Length": "12", + }, + ) + mime_type, filename, size = _get_remote_file_info("http://example.com/static/file.txt") + assert filename == "file.txt" + assert mime_type == "text/plain" + assert size == 12 + + def test_no_filename_in_url_or_header_generates_uuid_bin(self, monkeypatch: pytest.MonkeyPatch): + _mock_head( + monkeypatch, + { + "Content-Disposition": "inline", + "Content-Type": "application/octet-stream", + }, + ) + mime_type, filename, _ = _get_remote_file_info("http://example.com/test/") + # Should generate a random hex filename with .bin extension + assert re.match(r"^[0-9a-f]{32}\.bin$", filename) is not None + assert mime_type == "application/octet-stream" diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index b7701055f5..85789bfa7e 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -11,7 +11,7 @@ class TestExtractTenantId: def test_extract_tenant_id_from_account_with_tenant(self): """Test extracting tenant_id from Account with current_tenant_id.""" # Create a mock Account object - account = Account() + account = Account(name="test", email="test@example.com") # Mock the current_tenant_id property account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})() @@ -21,7 +21,7 @@ class TestExtractTenantId: def test_extract_tenant_id_from_account_without_tenant(self): """Test extracting tenant_id from Account without current_tenant_id.""" # Create a mock Account object - account = Account() + account = Account(name="test", email="test@example.com") account._current_tenant = None tenant_id = extract_tenant_id(account) diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py index 629d15b81a..b6595a8c57 100644 --- a/api/tests/unit_tests/libs/test_oauth_clients.py +++ b/api/tests/unit_tests/libs/test_oauth_clients.py @@ -1,8 +1,8 @@ import urllib.parse from unittest.mock import MagicMock, patch +import httpx import pytest -import requests from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo @@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest): ({}, None, True), ], ) - @patch("requests.post") + @patch("httpx.post") def test_should_retrieve_access_token( self, mock_post, oauth, mock_response, response_data, expected_token, should_raise ): @@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest): ), ], ) - @patch("requests.get") + @patch("httpx.get") def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email): user_response = MagicMock() user_response.json.return_value = user_data @@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest): assert user_info.name == user_data["name"] assert user_info.email == expected_email - @patch("requests.get") + @patch("httpx.get") def test_should_handle_network_errors(self, mock_get, oauth): - mock_get.side_effect = requests.exceptions.RequestException("Network error") + mock_get.side_effect = httpx.RequestError("Network error") - with pytest.raises(requests.exceptions.RequestException): + with pytest.raises(httpx.RequestError): oauth.get_raw_user_info("test_token") @@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest): ({}, None, True), ], ) - @patch("requests.post") + @patch("httpx.post") def test_should_retrieve_access_token( self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise ): @@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest): ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string ], ) - @patch("requests.get") + @patch("httpx.get") def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name): mock_response.json.return_value = user_data mock_get.return_value = mock_response @@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest): @pytest.mark.parametrize( "exception_type", [ - requests.exceptions.HTTPError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, + httpx.HTTPError, + httpx.ConnectError, + httpx.TimeoutException, ], ) - @patch("requests.get") + @patch("httpx.get") def test_should_handle_http_errors(self, mock_get, oauth, exception_type): mock_response = MagicMock() mock_response.raise_for_status.side_effect = exception_type("Error") diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 04988e85d8..1659205ec0 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from tos import TosClientV2 # type: ignore @@ -13,7 +15,13 @@ class TestVolcengineTos(BaseStorageTest): @pytest.fixture(autouse=True) def setup_method(self, setup_volcengine_tos_mock): """Executed before each test method.""" - self.storage = VolcengineTosStorage() + with patch("extensions.storage.volcengine_tos_storage.dify_config") as mock_config: + mock_config.VOLCENGINE_TOS_ACCESS_KEY = "test_access_key" + mock_config.VOLCENGINE_TOS_SECRET_KEY = "test_secret_key" + mock_config.VOLCENGINE_TOS_ENDPOINT = "test_endpoint" + mock_config.VOLCENGINE_TOS_REGION = "test_region" + self.storage = VolcengineTosStorage() + self.storage.bucket_name = get_example_bucket() self.storage.client = TosClientV2( ak="dify", diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index fadd1ee88f..5cba43714a 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -59,12 +59,11 @@ def session(): @pytest.fixture def mock_user(): """Create a user instance for testing.""" - user = Account() + user = Account(name="test", email="test@example.com") user.id = "test-user-id" - tenant = Tenant() + tenant = Tenant(name="Test Workspace") tenant.id = "test-tenant" - tenant.name = "Test Workspace" user._current_tenant = MagicMock() user._current_tenant.id = "test-tenant" @@ -299,7 +298,7 @@ def test_to_domain_model(repository): db_model.predecessor_node_id = "test-predecessor-id" db_model.node_execution_id = "test-node-execution-id" db_model.node_id = "test-node-id" - db_model.node_type = NodeType.START.value + db_model.node_type = NodeType.START db_model.title = "Test Node" db_model.inputs = json.dumps(inputs_dict) db_model.process_data = json.dumps(process_data_dict) diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py index bb39b92c09..acfc5cc526 100644 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ b/api/tests/unit_tests/services/auth/test_auth_integration.py @@ -6,8 +6,8 @@ import json from concurrent.futures import ThreadPoolExecutor from unittest.mock import Mock, patch +import httpx import pytest -import requests from services.auth.api_key_auth_factory import ApiKeyAuthFactory from services.auth.api_key_auth_service import ApiKeyAuthService @@ -26,7 +26,7 @@ class TestAuthIntegration: self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}} @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session): """Test complete authentication flow: request → validation → encryption → storage""" @@ -47,7 +47,7 @@ class TestAuthIntegration: mock_session.add.assert_called_once() mock_session.commit.assert_called_once() - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_cross_component_integration(self, mock_http): """Test factory → provider → HTTP call integration""" mock_http.return_value = self._create_success_response() @@ -97,7 +97,7 @@ class TestAuthIntegration: assert "another_secret" not in factory_str @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session): """Test concurrent authentication creation safety""" @@ -142,31 +142,31 @@ class TestAuthIntegration: with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_http_error_handling(self, mock_http): """Test proper HTTP error handling""" mock_response = Mock() mock_response.status_code = 401 mock_response.text = '{"error": "Unauthorized"}' - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized") + mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized") mock_http.return_value = mock_response # PT012: Split into single statement for pytest.raises factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - with pytest.raises((requests.exceptions.HTTPError, Exception)): + with pytest.raises((httpx.HTTPError, Exception)): factory.validate_credentials() @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_network_failure_recovery(self, mock_http, mock_session): """Test system recovery from network failures""" - mock_http.side_effect = requests.exceptions.RequestException("Network timeout") + mock_http.side_effect = httpx.RequestError("Network timeout") mock_session.add = Mock() mock_session.commit = Mock() args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - with pytest.raises(requests.exceptions.RequestException): + with pytest.raises(httpx.RequestError): ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) mock_session.commit.assert_not_called() diff --git a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py index ffdf5897ed..b5ee55706d 100644 --- a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py +++ b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch +import httpx import pytest -import requests from services.auth.firecrawl.firecrawl import FirecrawlAuth @@ -64,7 +64,7 @@ class TestFirecrawlAuth: FirecrawlAuth(credentials) assert str(exc_info.value) == expected_error - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance): """Test successful credential validation""" mock_response = MagicMock() @@ -95,7 +95,7 @@ class TestFirecrawlAuth: (500, "Internal server error"), ], ) - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance): """Test handling of various HTTP error codes""" mock_response = MagicMock() @@ -115,7 +115,7 @@ class TestFirecrawlAuth: (401, "Not JSON", True, "Expecting value"), # JSON decode error ], ) - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_should_handle_unexpected_errors( self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance ): @@ -134,13 +134,13 @@ class TestFirecrawlAuth: @pytest.mark.parametrize( ("exception_type", "exception_message"), [ - (requests.ConnectionError, "Network error"), - (requests.Timeout, "Request timeout"), - (requests.ReadTimeout, "Read timeout"), - (requests.ConnectTimeout, "Connection timeout"), + (httpx.ConnectError, "Network error"), + (httpx.TimeoutException, "Request timeout"), + (httpx.ReadTimeout, "Read timeout"), + (httpx.ConnectTimeout, "Connection timeout"), ], ) - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance): """Test handling of various network-related errors including timeouts""" mock_post.side_effect = exception_type(exception_message) @@ -162,7 +162,7 @@ class TestFirecrawlAuth: FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}}) assert "super_secret_key_12345" not in str(exc_info.value) - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_should_use_custom_base_url_in_validation(self, mock_post): """Test that custom base URL is used in validation""" mock_response = MagicMock() @@ -179,12 +179,12 @@ class TestFirecrawlAuth: assert result is True assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl" - @patch("services.auth.firecrawl.firecrawl.requests.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post") def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance): """Test that timeout errors are handled gracefully with appropriate error message""" - mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds") + mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds") - with pytest.raises(requests.Timeout) as exc_info: + with pytest.raises(httpx.TimeoutException) as exc_info: auth_instance.validate_credentials() # Verify the timeout exception is raised with original message diff --git a/api/tests/unit_tests/services/auth/test_jina_auth.py b/api/tests/unit_tests/services/auth/test_jina_auth.py index ccbca5a36f..4d2f300d25 100644 --- a/api/tests/unit_tests/services/auth/test_jina_auth.py +++ b/api/tests/unit_tests/services/auth/test_jina_auth.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch +import httpx import pytest -import requests from services.auth.jina.jina import JinaAuth @@ -35,7 +35,7 @@ class TestJinaAuth: JinaAuth(credentials) assert str(exc_info.value) == "No API key provided" - @patch("services.auth.jina.jina.requests.post") + @patch("services.auth.jina.jina.httpx.post") def test_should_validate_valid_credentials_successfully(self, mock_post): """Test successful credential validation""" mock_response = MagicMock() @@ -53,7 +53,7 @@ class TestJinaAuth: json={"url": "https://example.com"}, ) - @patch("services.auth.jina.jina.requests.post") + @patch("services.auth.jina.jina.httpx.post") def test_should_handle_http_402_error(self, mock_post): """Test handling of 402 Payment Required error""" mock_response = MagicMock() @@ -68,7 +68,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required" - @patch("services.auth.jina.jina.requests.post") + @patch("services.auth.jina.jina.httpx.post") def test_should_handle_http_409_error(self, mock_post): """Test handling of 409 Conflict error""" mock_response = MagicMock() @@ -83,7 +83,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" - @patch("services.auth.jina.jina.requests.post") + @patch("services.auth.jina.jina.httpx.post") def test_should_handle_http_500_error(self, mock_post): """Test handling of 500 Internal Server Error""" mock_response = MagicMock() @@ -98,7 +98,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error" - @patch("services.auth.jina.jina.requests.post") + @patch("services.auth.jina.jina.httpx.post") def test_should_handle_unexpected_error_with_text_response(self, mock_post): """Test handling of unexpected errors with text response""" mock_response = MagicMock() @@ -114,7 +114,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden" - @patch("services.auth.jina.jina.requests.post") + @patch("services.auth.jina.jina.httpx.post") def test_should_handle_unexpected_error_without_text(self, mock_post): """Test handling of unexpected errors without text response""" mock_response = MagicMock() @@ -130,15 +130,15 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404" - @patch("services.auth.jina.jina.requests.post") + @patch("services.auth.jina.jina.httpx.post") def test_should_handle_network_errors(self, mock_post): """Test handling of network connection errors""" - mock_post.side_effect = requests.ConnectionError("Network error") + mock_post.side_effect = httpx.ConnectError("Network error") credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} auth = JinaAuth(credentials) - with pytest.raises(requests.ConnectionError): + with pytest.raises(httpx.ConnectError): auth.validate_credentials() def test_should_not_expose_api_key_in_error_messages(self): diff --git a/api/tests/unit_tests/services/auth/test_watercrawl_auth.py b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py index bacf0b24ea..ec99cb10b0 100644 --- a/api/tests/unit_tests/services/auth/test_watercrawl_auth.py +++ b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch +import httpx import pytest -import requests from services.auth.watercrawl.watercrawl import WatercrawlAuth @@ -64,7 +64,7 @@ class TestWatercrawlAuth: WatercrawlAuth(credentials) assert str(exc_info.value) == expected_error - @patch("services.auth.watercrawl.watercrawl.requests.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get") def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance): """Test successful credential validation""" mock_response = MagicMock() @@ -87,7 +87,7 @@ class TestWatercrawlAuth: (500, "Internal server error"), ], ) - @patch("services.auth.watercrawl.watercrawl.requests.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get") def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance): """Test handling of various HTTP error codes""" mock_response = MagicMock() @@ -107,7 +107,7 @@ class TestWatercrawlAuth: (401, "Not JSON", True, "Expecting value"), # JSON decode error ], ) - @patch("services.auth.watercrawl.watercrawl.requests.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get") def test_should_handle_unexpected_errors( self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance ): @@ -126,13 +126,13 @@ class TestWatercrawlAuth: @pytest.mark.parametrize( ("exception_type", "exception_message"), [ - (requests.ConnectionError, "Network error"), - (requests.Timeout, "Request timeout"), - (requests.ReadTimeout, "Read timeout"), - (requests.ConnectTimeout, "Connection timeout"), + (httpx.ConnectError, "Network error"), + (httpx.TimeoutException, "Request timeout"), + (httpx.ReadTimeout, "Read timeout"), + (httpx.ConnectTimeout, "Connection timeout"), ], ) - @patch("services.auth.watercrawl.watercrawl.requests.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get") def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance): """Test handling of various network-related errors including timeouts""" mock_get.side_effect = exception_type(exception_message) @@ -154,7 +154,7 @@ class TestWatercrawlAuth: WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}) assert "super_secret_key_12345" not in str(exc_info.value) - @patch("services.auth.watercrawl.watercrawl.requests.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get") def test_should_use_custom_base_url_in_validation(self, mock_get): """Test that custom base URL is used in validation""" mock_response = MagicMock() @@ -179,7 +179,7 @@ class TestWatercrawlAuth: ("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), ], ) - @patch("services.auth.watercrawl.watercrawl.requests.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get") def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url): """Test that urljoin is used correctly for URL construction with various base URLs""" mock_response = MagicMock() @@ -193,12 +193,12 @@ class TestWatercrawlAuth: # Verify the correct URL was called assert mock_get.call_args[0][0] == expected_url - @patch("services.auth.watercrawl.watercrawl.requests.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get") def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance): """Test that timeout errors are handled gracefully with appropriate error message""" - mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds") + mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds") - with pytest.raises(requests.Timeout) as exc_info: + with pytest.raises(httpx.TimeoutException) as exc_info: auth_instance.validate_credentials() # Verify the timeout exception is raised with original message diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index 0ff1edc950..31fe9b2868 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -118,7 +118,7 @@ class TestMetadataBugCompleteValidation: # But would crash when trying to create MetadataArgs with pytest.raises((ValueError, TypeError)): - MetadataArgs(**args) + MetadataArgs.model_validate(args) def test_7_end_to_end_validation_layers(self): """Test all validation layers work together correctly.""" @@ -131,7 +131,7 @@ class TestMetadataBugCompleteValidation: valid_data = {"type": "string", "name": "test_metadata"} # Should create valid Pydantic object - metadata_args = MetadataArgs(**valid_data) + metadata_args = MetadataArgs.model_validate(valid_data) assert metadata_args.type == "string" assert metadata_args.name == "test_metadata" diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index d151100cf3..c8cd7025c2 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -76,7 +76,7 @@ class TestMetadataNullableBug: # Step 2: Try to create MetadataArgs with None values # This should fail at Pydantic validation level with pytest.raises((ValueError, TypeError)): - metadata_args = MetadataArgs(**args) + metadata_args = MetadataArgs.model_validate(args) # Step 3: If we bypass Pydantic (simulating the bug scenario) # Move this outside the request context to avoid Flask-Login issues diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 0ad056c985..6761f939e3 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -588,3 +588,11 @@ class TestIntegrationScenarios: if isinstance(result.result, ObjectSegment): result_size = truncator.calculate_json_size(result.result.value) assert result_size <= original_size + + def test_file_and_array_file_variable_mapping(self, file): + truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300) + + mapping = {"array_file": [file]} + truncated_mapping, truncated = truncator.truncate_variable_mapping(mapping) + assert truncated is False + assert truncated_mapping == mapping diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 2ca781bae5..63ce4c0c3c 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -107,7 +107,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): assert body_data body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY body_params = body_data_json["params"] assert body_params["app_id"] == app_model.id @@ -168,7 +168,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): assert body_data body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY body_params = body_data_json["params"] assert body_params["app_id"] == app_model.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 7e324ca4db..66361f26e0 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -47,7 +47,8 @@ class TestDraftVariableSaver: def test__should_variable_be_visible(self): mock_session = MagicMock(spec=Session) - mock_user = Account(id=str(uuid.uuid4())) + mock_user = Account(name="test", email="test@example.com") + mock_user.id = str(uuid.uuid4()) test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session, diff --git a/api/uv.lock b/api/uv.lock index ee49e79eff..49339129e1 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and sys_platform == 'linux'", @@ -445,16 +445,17 @@ wheels = [ [[package]] name = "azure-storage-blob" -version = "12.13.0" +version = "12.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "azure-core" }, { name = "cryptography" }, - { name = "msrest" }, + { name = "isodate" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b1/93/b13bf390e940a79a399981f75ac8d2e05a70112a95ebb7b41e9b752d2921/azure-storage-blob-12.13.0.zip", hash = "sha256:53f0d4cd32970ac9ff9b9753f83dd2fb3f9ac30e1d01e71638c436c509bfd884", size = 684838, upload-time = "2022-07-07T22:35:44.543Z" } +sdist = { url = "https://files.pythonhosted.org/packages/96/95/3e3414491ce45025a1cde107b6ae72bf72049e6021597c201cd6a3029b9a/azure_storage_blob-12.26.0.tar.gz", hash = "sha256:5dd7d7824224f7de00bfeb032753601c982655173061e242f13be6e26d78d71f", size = 583332, upload-time = "2025-07-16T21:34:07.644Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/2a/b8246df35af68d64fb7292c93dbbde63cd25036f2f669a9d9ae59e518c76/azure_storage_blob-12.13.0-py3-none-any.whl", hash = "sha256:280a6ab032845bab9627582bee78a50497ca2f14772929b5c5ee8b4605af0cb3", size = 377309, upload-time = "2022-07-07T22:35:41.905Z" }, + { url = "https://files.pythonhosted.org/packages/5b/64/63dbfdd83b31200ac58820a7951ddfdeed1fbee9285b0f3eae12d1357155/azure_storage_blob-12.26.0-py3-none-any.whl", hash = "sha256:8c5631b8b22b4f53ec5fff2f3bededf34cfef111e2af613ad42c9e6de00a77fe", size = 412907, upload-time = "2025-07-16T21:34:09.367Z" }, ] [[package]] @@ -1076,7 +1077,7 @@ wheels = [ [[package]] name = "cos-python-sdk-v5" -version = "1.9.30" +version = "1.9.38" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "crcmod" }, @@ -1085,7 +1086,10 @@ dependencies = [ { name = "six" }, { name = "xmltodict" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/f2/be99b41433b33a76896680920fca621f191875ca410a66778015e47a501b/cos-python-sdk-v5-1.9.30.tar.gz", hash = "sha256:a23fd090211bf90883066d90cd74317860aa67c6d3aa80fe5e44b18c7e9b2a81", size = 108384, upload-time = "2024-06-14T08:02:37.063Z" } +sdist = { url = "https://files.pythonhosted.org/packages/24/3c/d208266fec7cc3221b449e236b87c3fc1999d5ac4379d4578480321cfecc/cos_python_sdk_v5-1.9.38.tar.gz", hash = "sha256:491a8689ae2f1a6f04dacba66a877b2c8d361456f9cfd788ed42170a1cbf7a9f", size = 98092, upload-time = "2025-07-22T07:56:20.34Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/c8/c9c156aa3bc7caba9b4f8a2b6abec3da6263215988f3fec0ea843f137a10/cos_python_sdk_v5-1.9.38-py3-none-any.whl", hash = "sha256:1d3dd3be2bd992b2e9c2dcd018e2596aa38eab022dbc86b4a5d14c8fc88370e6", size = 92601, upload-time = "2025-08-17T05:12:30.867Z" }, +] [[package]] name = "couchbase" @@ -1273,11 +1277,10 @@ wheels = [ [[package]] name = "dify-api" -version = "2.0.0b2" +version = "1.9.1" source = { virtual = "." } dependencies = [ { name = "arize-phoenix-otel" }, - { name = "authlib" }, { name = "azure-identity" }, { name = "beautifulsoup4" }, { name = "boto3" }, @@ -1308,10 +1311,8 @@ dependencies = [ { name = "json-repair" }, { name = "langfuse" }, { name = "langsmith" }, - { name = "mailchimp-transactional" }, { name = "markdown" }, { name = "numpy" }, - { name = "openai" }, { name = "openpyxl" }, { name = "opentelemetry-api" }, { name = "opentelemetry-distro" }, @@ -1322,8 +1323,8 @@ dependencies = [ { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-instrumentation-celery" }, { name = "opentelemetry-instrumentation-flask" }, + { name = "opentelemetry-instrumentation-httpx" }, { name = "opentelemetry-instrumentation-redis" }, - { name = "opentelemetry-instrumentation-requests" }, { name = "opentelemetry-instrumentation-sqlalchemy" }, { name = "opentelemetry-propagator-b3" }, { name = "opentelemetry-proto" }, @@ -1333,7 +1334,6 @@ dependencies = [ { name = "opik" }, { name = "packaging" }, { name = "pandas", extra = ["excel", "output-formatting", "performance"] }, - { name = "pandoc" }, { name = "psycogreen" }, { name = "psycopg2-binary" }, { name = "pycryptodome" }, @@ -1417,8 +1417,6 @@ dev = [ { name = "types-pyyaml" }, { name = "types-redis" }, { name = "types-regex" }, - { name = "types-requests" }, - { name = "types-requests-oauthlib" }, { name = "types-setuptools" }, { name = "types-shapely" }, { name = "types-simplejson" }, @@ -1451,6 +1449,7 @@ vdb = [ { name = "couchbase" }, { name = "elasticsearch" }, { name = "mo-vector" }, + { name = "mysql-connector-python" }, { name = "opensearch-py" }, { name = "oracledb" }, { name = "pgvecto-rs", extra = ["sqlalchemy"] }, @@ -1471,7 +1470,6 @@ vdb = [ [package.metadata] requires-dist = [ { name = "arize-phoenix-otel", specifier = "~=0.9.2" }, - { name = "authlib", specifier = "==1.6.4" }, { name = "azure-identity", specifier = "==1.16.1" }, { name = "beautifulsoup4", specifier = "==4.12.2" }, { name = "boto3", specifier = "==1.35.99" }, @@ -1502,10 +1500,8 @@ requires-dist = [ { name = "json-repair", specifier = ">=0.41.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, - { name = "mailchimp-transactional", specifier = "~=1.0.50" }, { name = "markdown", specifier = "~=3.5.1" }, { name = "numpy", specifier = "~=1.26.4" }, - { name = "openai", specifier = "~=1.61.0" }, { name = "openpyxl", specifier = "~=3.1.5" }, { name = "opentelemetry-api", specifier = "==1.27.0" }, { name = "opentelemetry-distro", specifier = "==0.48b0" }, @@ -1516,8 +1512,8 @@ requires-dist = [ { name = "opentelemetry-instrumentation", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-celery", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" }, + { name = "opentelemetry-instrumentation-httpx", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" }, - { name = "opentelemetry-instrumentation-requests", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" }, { name = "opentelemetry-propagator-b3", specifier = "==1.27.0" }, { name = "opentelemetry-proto", specifier = "==1.27.0" }, @@ -1527,7 +1523,6 @@ requires-dist = [ { name = "opik", specifier = "~=1.7.25" }, { name = "packaging", specifier = "~=23.2" }, { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=2.2.2" }, - { name = "pandoc", specifier = "~=2.4" }, { name = "psycogreen", specifier = "~=1.0.2" }, { name = "psycopg2-binary", specifier = "~=2.9.6" }, { name = "pycryptodome", specifier = "==3.19.1" }, @@ -1573,7 +1568,7 @@ dev = [ { name = "pytest-cov", specifier = "~=4.1.0" }, { name = "pytest-env", specifier = "~=1.1.3" }, { name = "pytest-mock", specifier = "~=3.14.0" }, - { name = "ruff", specifier = "~=0.12.3" }, + { name = "ruff", specifier = "~=0.14.0" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.10.0" }, @@ -1611,8 +1606,6 @@ dev = [ { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-redis", specifier = ">=4.6.0.20241004" }, { name = "types-regex", specifier = "~=2024.11.6" }, - { name = "types-requests", specifier = "~=2.32.0" }, - { name = "types-requests-oauthlib", specifier = "~=2.0.0" }, { name = "types-setuptools", specifier = ">=80.9.0" }, { name = "types-shapely", specifier = "~=2.0.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, @@ -1622,10 +1615,10 @@ dev = [ { name = "types-ujson", specifier = ">=5.10.0" }, ] storage = [ - { name = "azure-storage-blob", specifier = "==12.13.0" }, + { name = "azure-storage-blob", specifier = "==12.26.0" }, { name = "bce-python-sdk", specifier = "~=0.9.23" }, - { name = "cos-python-sdk-v5", specifier = "==1.9.30" }, - { name = "esdk-obs-python", specifier = "==3.24.6.1" }, + { name = "cos-python-sdk-v5", specifier = "==1.9.38" }, + { name = "esdk-obs-python", specifier = "==3.25.8" }, { name = "google-cloud-storage", specifier = "==2.16.0" }, { name = "opendal", specifier = "~=0.46.0" }, { name = "oss2", specifier = "==2.18.5" }, @@ -1645,8 +1638,9 @@ vdb = [ { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, + { name = "mysql-connector-python", specifier = ">=9.3.0" }, { name = "opensearch-py", specifier = "==2.4.0" }, - { name = "oracledb", specifier = "==3.0.0" }, + { name = "oracledb", specifier = "==3.3.0" }, { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" }, { name = "pgvector", specifier = "==0.2.5" }, { name = "pymilvus", specifier = "~=2.5.0" }, @@ -1776,12 +1770,14 @@ wheels = [ [[package]] name = "esdk-obs-python" -version = "3.24.6.1" +version = "3.25.8" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "crcmod" }, { name = "pycryptodome" }, + { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f7/af/d83276f9e288bd6a62f44d67ae1eafd401028ba1b2b643ae4014b51da5bd/esdk-obs-python-3.24.6.1.tar.gz", hash = "sha256:c45fed143e99d9256c8560c1d78f651eae0d2e809d16e962f8b286b773c33bf0", size = 85798, upload-time = "2024-07-26T13:13:22.467Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/99/52362d6e081a642d6de78f6ab53baa5e3f82f2386c48954e18ee7b4ab22b/esdk-obs-python-3.25.8.tar.gz", hash = "sha256:aeded00b27ecd5a25ffaec38a2cc9416b51923d48db96c663f1a735f859b5273", size = 96302, upload-time = "2025-09-01T11:35:20.432Z" } [[package]] name = "et-xmlfile" @@ -3166,21 +3162,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6c/e1/0686c91738f3e6c2e1a243e0fdd4371667c4d2e5009b0a3605806c2aa020/lz4-4.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:2f4f2965c98ab254feddf6b5072854a6935adab7bc81412ec4fe238f07b85f62", size = 89736, upload-time = "2025-04-01T22:55:40.5Z" }, ] -[[package]] -name = "mailchimp-transactional" -version = "1.0.56" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "python-dateutil" }, - { name = "requests" }, - { name = "six" }, - { name = "urllib3" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/bc/cb60d02c00996839bbd87444a97d0ba5ac271b1a324001562afb8f685251/mailchimp_transactional-1.0.56-py3-none-any.whl", hash = "sha256:a76ea88b90a2d47d8b5134586aabbd3a96c459f6066d8886748ab59e50de36eb", size = 31660, upload-time = "2024-02-01T18:39:19.717Z" }, -] - [[package]] name = "mako" version = "1.3.10" @@ -3366,22 +3347,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/75/bd9b7bb966668920f06b200e84454c8f3566b102183bc55c5473d96cb2b9/msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca", size = 20583, upload-time = "2025-03-14T23:51:03.016Z" }, ] -[[package]] -name = "msrest" -version = "0.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "azure-core" }, - { name = "certifi" }, - { name = "isodate" }, - { name = "requests" }, - { name = "requests-oauthlib" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/68/77/8397c8fb8fc257d8ea0fa66f8068e073278c65f05acb17dcb22a02bfdc42/msrest-0.7.1.zip", hash = "sha256:6e7661f46f3afd88b75667b7187a92829924446c7ea1d169be8c4bb7eeb788b9", size = 175332, upload-time = "2022-06-13T22:41:25.111Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/15/cf/f2966a2638144491f8696c27320d5219f48a072715075d168b31d3237720/msrest-0.7.1-py3-none-any.whl", hash = "sha256:21120a810e1233e5e6cc7fe40b474eeb4ec6f757a15d7cf86702c369f9567c32", size = 85384, upload-time = "2022-06-13T22:41:22.42Z" }, -] - [[package]] name = "multidict" version = "6.6.4" @@ -3474,6 +3439,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "mysql-connector-python" +version = "9.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/77/2b45e6460d05b1f1b7a4c8eb79a50440b4417971973bb78c9ef6cad630a6/mysql_connector_python-9.4.0.tar.gz", hash = "sha256:d111360332ae78933daf3d48ff497b70739aa292ab0017791a33e826234e743b", size = 12185532, upload-time = "2025-07-22T08:02:05.788Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/0c/4365a802129be9fa63885533c38be019f1c6b6f5bcf8844ac53902314028/mysql_connector_python-9.4.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:7df1a8ddd182dd8adc914f6dc902a986787bf9599705c29aca7b2ce84e79d361", size = 17501627, upload-time = "2025-07-22T07:57:45.416Z" }, + { url = "https://files.pythonhosted.org/packages/c0/bf/ca596c00d7a6eaaf8ef2f66c9b23cd312527f483073c43ffac7843049cb4/mysql_connector_python-9.4.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:3892f20472e13e63b1fb4983f454771dd29f211b09724e69a9750e299542f2f8", size = 18369494, upload-time = "2025-07-22T07:57:49.714Z" }, + { url = "https://files.pythonhosted.org/packages/25/14/6510a11ed9f80d77f743dc207773092c4ab78d5efa454b39b48480315d85/mysql_connector_python-9.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:d3e87142103d71c4df647ece30f98e85e826652272ed1c74822b56f6acdc38e7", size = 33516187, upload-time = "2025-07-22T07:57:55.294Z" }, + { url = "https://files.pythonhosted.org/packages/16/a8/4f99d80f1cf77733ce9a44b6adb7f0dd7079e7afa51ca4826515ef0c3e16/mysql_connector_python-9.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:b27fcd403436fe83bafb2fe7fcb785891e821e639275c4ad3b3bd1e25f533206", size = 33917818, upload-time = "2025-07-22T07:58:00.523Z" }, + { url = "https://files.pythonhosted.org/packages/15/9c/127f974ca9d5ee25373cb5433da06bb1f36e05f2a6b7436da1fe9c6346b0/mysql_connector_python-9.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd6ff5afb9c324b0bbeae958c93156cce4168c743bf130faf224d52818d1f0ee", size = 16392378, upload-time = "2025-07-22T07:58:04.669Z" }, + { url = "https://files.pythonhosted.org/packages/03/7c/a543fb17c2dfa6be8548dfdc5879a0c7924cd5d1c79056c48472bb8fe858/mysql_connector_python-9.4.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:4efa3898a24aba6a4bfdbf7c1f5023c78acca3150d72cc91199cca2ccd22f76f", size = 17503693, upload-time = "2025-07-22T07:58:08.96Z" }, + { url = "https://files.pythonhosted.org/packages/cb/6e/c22fbee05f5cfd6ba76155b6d45f6261d8d4c1e36e23de04e7f25fbd01a4/mysql_connector_python-9.4.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:665c13e7402235162e5b7a2bfdee5895192121b64ea455c90a81edac6a48ede5", size = 18371987, upload-time = "2025-07-22T07:58:13.273Z" }, + { url = "https://files.pythonhosted.org/packages/b4/fd/f426f5f35a3d3180c7f84d1f96b4631be2574df94ca1156adab8618b236c/mysql_connector_python-9.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:815aa6cad0f351c1223ef345781a538f2e5e44ef405fdb3851eb322bd9c4ca2b", size = 33516214, upload-time = "2025-07-22T07:58:18.967Z" }, + { url = "https://files.pythonhosted.org/packages/45/5a/1b053ae80b43cd3ccebc4bb99a98826969b3b0f8adebdcc2530750ad76ed/mysql_connector_python-9.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b3436a2c8c0ec7052932213e8d01882e6eb069dbab33402e685409084b133a1c", size = 33918565, upload-time = "2025-07-22T07:58:25.28Z" }, + { url = "https://files.pythonhosted.org/packages/cb/69/36b989de675d98ba8ff7d45c96c30c699865c657046f2e32db14e78f13d9/mysql_connector_python-9.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:57b0c224676946b70548c56798d5023f65afa1ba5b8ac9f04a143d27976c7029", size = 16392563, upload-time = "2025-07-22T07:58:29.623Z" }, + { url = "https://files.pythonhosted.org/packages/36/34/b6165e15fd45a8deb00932d8e7d823de7650270873b4044c4db6688e1d8f/mysql_connector_python-9.4.0-py2.py3-none-any.whl", hash = "sha256:56e679169c704dab279b176fab2a9ee32d2c632a866c0f7cd48a8a1e2cf802c4", size = 406574, upload-time = "2025-07-22T07:59:08.394Z" }, +] + [[package]] name = "nest-asyncio" version = "1.6.0" @@ -3911,6 +3895,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/3d/fcde4f8f0bf9fa1ee73a12304fa538076fb83fe0a2ae966ab0f0b7da5109/opentelemetry_instrumentation_flask-0.48b0-py3-none-any.whl", hash = "sha256:26b045420b9d76e85493b1c23fcf27517972423480dc6cf78fd6924248ba5808", size = 14588, upload-time = "2024-08-28T21:26:58.504Z" }, ] +[[package]] +name = "opentelemetry-instrumentation-httpx" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/d9/c65d818607c16d1b7ea8d2de6111c6cecadf8d2fd38c1885a72733a7c6d3/opentelemetry_instrumentation_httpx-0.48b0.tar.gz", hash = "sha256:ee977479e10398931921fb995ac27ccdeea2e14e392cb27ef012fc549089b60a", size = 16931, upload-time = "2024-08-28T21:28:03.794Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/fe/f2daa9d6d988c093b8c7b1d35df675761a8ece0b600b035dc04982746c9d/opentelemetry_instrumentation_httpx-0.48b0-py3-none-any.whl", hash = "sha256:d94f9d612c82d09fe22944d1904a30a464c19bea2ba76be656c99a28ad8be8e5", size = 13900, upload-time = "2024-08-28T21:27:01.566Z" }, +] + [[package]] name = "opentelemetry-instrumentation-redis" version = "0.48b0" @@ -3926,21 +3925,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" }, ] -[[package]] -name = "opentelemetry-instrumentation-requests" -version = "0.48b0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "opentelemetry-api" }, - { name = "opentelemetry-instrumentation" }, - { name = "opentelemetry-semantic-conventions" }, - { name = "opentelemetry-util-http" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/52/ac/5eb78efde21ff21d0ad5dc8c6cc6a0f8ae482ce8a46293c2f45a628b6166/opentelemetry_instrumentation_requests-0.48b0.tar.gz", hash = "sha256:67ab9bd877a0352ee0db4616c8b4ae59736ddd700c598ed907482d44f4c9a2b3", size = 14120, upload-time = "2024-08-28T21:28:16.933Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/df/0df9226d1b14f29d23c07e6194b9fd5ad50e7d987b7fd13df7dcf718aeb1/opentelemetry_instrumentation_requests-0.48b0-py3-none-any.whl", hash = "sha256:d4f01852121d0bd4c22f14f429654a735611d4f7bf3cf93f244bdf1489b2233d", size = 12366, upload-time = "2024-08-28T21:27:20.771Z" }, -] - [[package]] name = "opentelemetry-instrumentation-sqlalchemy" version = "0.48b0" @@ -4079,23 +4063,23 @@ numpy = [ [[package]] name = "oracledb" -version = "3.0.0" +version = "3.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bf/39/712f797b75705c21148fa1d98651f63c2e5cc6876e509a0a9e2f5b406572/oracledb-3.0.0.tar.gz", hash = "sha256:64dc86ee5c032febc556798b06e7b000ef6828bb0252084f6addacad3363db85", size = 840431, upload-time = "2025-03-03T19:36:12.223Z" } +sdist = { url = "https://files.pythonhosted.org/packages/51/c9/fae18fa5d803712d188486f8e86ad4f4e00316793ca19745d7c11092c360/oracledb-3.3.0.tar.gz", hash = "sha256:e830d3544a1578296bcaa54c6e8c8ae10a58c7db467c528c4b27adbf9c8b4cb0", size = 811776, upload-time = "2025-07-29T22:34:10.489Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/bf/d872c4b3fc15cd3261fe0ea72b21d181700c92dbc050160e161654987062/oracledb-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52daa9141c63dfa75c07d445e9bb7f69f43bfb3c5a173ecc48c798fe50288d26", size = 4312963, upload-time = "2025-03-03T19:36:32.576Z" }, - { url = "https://files.pythonhosted.org/packages/b1/ea/01ee29e76a610a53bb34fdc1030f04b7669c3f80b25f661e07850fc6160e/oracledb-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af98941789df4c6aaaf4338f5b5f6b7f2c8c3fe6f8d6a9382f177f350868747a", size = 2661536, upload-time = "2025-03-03T19:36:34.904Z" }, - { url = "https://files.pythonhosted.org/packages/3d/8e/ad380e34a46819224423b4773e58c350bc6269643c8969604097ced8c3bc/oracledb-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9812bb48865aaec35d73af54cd1746679f2a8a13cbd1412ab371aba2e39b3943", size = 2867461, upload-time = "2025-03-03T19:36:36.508Z" }, - { url = "https://files.pythonhosted.org/packages/96/09/ecc4384a27fd6e1e4de824ae9c160e4ad3aaebdaade5b4bdcf56a4d1ff63/oracledb-3.0.0-cp311-cp311-win32.whl", hash = "sha256:6c27fe0de64f2652e949eb05b3baa94df9b981a4a45fa7f8a991e1afb450c8e2", size = 1752046, upload-time = "2025-03-03T19:36:38.313Z" }, - { url = "https://files.pythonhosted.org/packages/62/e8/f34bde24050c6e55eeba46b23b2291f2dd7fd272fa8b322dcbe71be55778/oracledb-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:f922709672002f0b40997456f03a95f03e5712a86c61159951c5ce09334325e0", size = 2101210, upload-time = "2025-03-03T19:36:40.669Z" }, - { url = "https://files.pythonhosted.org/packages/6f/fc/24590c3a3d41e58494bd3c3b447a62835138e5f9b243d9f8da0cfb5da8dc/oracledb-3.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:acd0e747227dea01bebe627b07e958bf36588a337539f24db629dc3431d3f7eb", size = 4351993, upload-time = "2025-03-03T19:36:42.577Z" }, - { url = "https://files.pythonhosted.org/packages/b7/b6/1f3b0b7bb94d53e8857d77b2e8dbdf6da091dd7e377523e24b79dac4fd71/oracledb-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8b402f77c22af031cd0051aea2472ecd0635c1b452998f511aa08b7350c90a4", size = 2532640, upload-time = "2025-03-03T19:36:45.066Z" }, - { url = "https://files.pythonhosted.org/packages/72/1a/1815f6c086ab49c00921cf155ff5eede5267fb29fcec37cb246339a5ce4d/oracledb-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:378a27782e9a37918bd07a5a1427a77cb6f777d0a5a8eac9c070d786f50120ef", size = 2765949, upload-time = "2025-03-03T19:36:47.47Z" }, - { url = "https://files.pythonhosted.org/packages/33/8d/208900f8d372909792ee70b2daad3f7361181e55f2217c45ed9dff658b54/oracledb-3.0.0-cp312-cp312-win32.whl", hash = "sha256:54a28c2cb08316a527cd1467740a63771cc1c1164697c932aa834c0967dc4efc", size = 1709373, upload-time = "2025-03-03T19:36:49.67Z" }, - { url = "https://files.pythonhosted.org/packages/0c/5e/c21754f19c896102793c3afec2277e2180aa7d505e4d7fcca24b52d14e4f/oracledb-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8289bad6d103ce42b140e40576cf0c81633e344d56e2d738b539341eacf65624", size = 2056452, upload-time = "2025-03-03T19:36:51.363Z" }, + { url = "https://files.pythonhosted.org/packages/3f/35/95d9a502fdc48ce1ef3a513ebd027488353441e15aa0448619abb3d09d32/oracledb-3.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d9adb74f837838e21898d938e3a725cf73099c65f98b0b34d77146b453e945e0", size = 3963945, upload-time = "2025-07-29T22:34:28.633Z" }, + { url = "https://files.pythonhosted.org/packages/16/a7/8f1ef447d995bb51d9fdc36356697afeceb603932f16410c12d52b2df1a4/oracledb-3.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b063d1007882570f170ebde0f364e78d4a70c8f015735cc900663278b9ceef7", size = 2449385, upload-time = "2025-07-29T22:34:30.592Z" }, + { url = "https://files.pythonhosted.org/packages/b3/fa/6a78480450bc7d256808d0f38ade3385735fb5a90dab662167b4257dcf94/oracledb-3.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:187728f0a2d161676b8c581a9d8f15d9631a8fea1e628f6d0e9fa2f01280cd22", size = 2634943, upload-time = "2025-07-29T22:34:33.142Z" }, + { url = "https://files.pythonhosted.org/packages/5b/90/ea32b569a45fb99fac30b96f1ac0fb38b029eeebb78357bc6db4be9dde41/oracledb-3.3.0-cp311-cp311-win32.whl", hash = "sha256:920f14314f3402c5ab98f2efc5932e0547e9c0a4ca9338641357f73844e3e2b1", size = 1483549, upload-time = "2025-07-29T22:34:35.015Z" }, + { url = "https://files.pythonhosted.org/packages/81/55/ae60f72836eb8531b630299f9ed68df3fe7868c6da16f820a108155a21f9/oracledb-3.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:825edb97976468db1c7e52c78ba38d75ce7e2b71a2e88f8629bcf02be8e68a8a", size = 1834737, upload-time = "2025-07-29T22:34:36.824Z" }, + { url = "https://files.pythonhosted.org/packages/08/a8/f6b7809d70e98e113786d5a6f1294da81c046d2fa901ad656669fc5d7fae/oracledb-3.3.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9d25e37d640872731ac9b73f83cbc5fc4743cd744766bdb250488caf0d7696a8", size = 3943512, upload-time = "2025-07-29T22:34:39.237Z" }, + { url = "https://files.pythonhosted.org/packages/df/b9/8145ad8991f4864d3de4a911d439e5bc6cdbf14af448f3ab1e846a54210c/oracledb-3.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0bf7cdc2b668f939aa364f552861bc7a149d7cd3f3794730d43ef07613b2bf9", size = 2276258, upload-time = "2025-07-29T22:34:41.547Z" }, + { url = "https://files.pythonhosted.org/packages/56/bf/f65635ad5df17d6e4a2083182750bb136ac663ff0e9996ce59d77d200f60/oracledb-3.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2fe20540fde64a6987046807ea47af93be918fd70b9766b3eb803c01e6d4202e", size = 2458811, upload-time = "2025-07-29T22:34:44.648Z" }, + { url = "https://files.pythonhosted.org/packages/7d/30/e0c130b6278c10b0e6cd77a3a1a29a785c083c549676cf701c5d180b8e63/oracledb-3.3.0-cp312-cp312-win32.whl", hash = "sha256:db080be9345cbf9506ffdaea3c13d5314605355e76d186ec4edfa49960ffb813", size = 1445525, upload-time = "2025-07-29T22:34:46.603Z" }, + { url = "https://files.pythonhosted.org/packages/1a/5c/7254f5e1a33a5d6b8bf6813d4f4fdcf5c4166ec8a7af932d987879d5595c/oracledb-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:be81e3afe79f6c8ece79a86d6067ad1572d2992ce1c590a086f3755a09535eb4", size = 1789976, upload-time = "2025-07-29T22:34:48.5Z" }, ] [[package]] @@ -4228,16 +4212,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/f8/46141ba8c9d7064dc5008bfb4a6ae5bd3c30e4c61c28b5c5ed485bf358ba/pandas_stubs-2.2.3.250527-py3-none-any.whl", hash = "sha256:cd0a49a95b8c5f944e605be711042a4dd8550e2c559b43d70ba2c4b524b66163", size = 159683, upload-time = "2025-05-27T15:24:28.4Z" }, ] -[[package]] -name = "pandoc" -version = "2.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "plumbum" }, - { name = "ply" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/10/9a/e3186e760c57ee5f1c27ea5cea577a0ff9abfca51eefcb4d9a4cd39aff2e/pandoc-2.4.tar.gz", hash = "sha256:ecd1f8cbb7f4180c6b5db4a17a7c1a74df519995f5f186ef81ce72a9cbd0dd9a", size = 34635, upload-time = "2024-08-07T14:33:58.016Z" } - [[package]] name = "pathspec" version = "0.12.1" @@ -4344,18 +4318,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] -[[package]] -name = "plumbum" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f0/5d/49ba324ad4ae5b1a4caefafbce7a1648540129344481f2ed4ef6bb68d451/plumbum-1.9.0.tar.gz", hash = "sha256:e640062b72642c3873bd5bdc3effed75ba4d3c70ef6b6a7b907357a84d909219", size = 319083, upload-time = "2024-10-05T05:59:27.059Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4f/9d/d03542c93bb3d448406731b80f39c3d5601282f778328c22c77d270f4ed4/plumbum-1.9.0-py3-none-any.whl", hash = "sha256:9fd0d3b0e8d86e4b581af36edf3f3bbe9d1ae15b45b8caab28de1bcb27aaa7f5", size = 127970, upload-time = "2024-10-05T05:59:25.102Z" }, -] - [[package]] name = "ply" version = "3.11" @@ -5499,28 +5461,28 @@ wheels = [ [[package]] name = "ruff" -version = "0.12.12" +version = "0.14.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a8/f0/e0965dd709b8cabe6356811c0ee8c096806bb57d20b5019eb4e48a117410/ruff-0.12.12.tar.gz", hash = "sha256:b86cd3415dbe31b3b46a71c598f4c4b2f550346d1ccf6326b347cc0c8fd063d6", size = 5359915, upload-time = "2025-09-04T16:50:18.273Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/b9/9bd84453ed6dd04688de9b3f3a4146a1698e8faae2ceeccce4e14c67ae17/ruff-0.14.0.tar.gz", hash = "sha256:62ec8969b7510f77945df916de15da55311fade8d6050995ff7f680afe582c57", size = 5452071, upload-time = "2025-10-07T18:21:55.763Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/09/79/8d3d687224d88367b51c7974cec1040c4b015772bfbeffac95face14c04a/ruff-0.12.12-py3-none-linux_armv6l.whl", hash = "sha256:de1c4b916d98ab289818e55ce481e2cacfaad7710b01d1f990c497edf217dafc", size = 12116602, upload-time = "2025-09-04T16:49:18.892Z" }, - { url = "https://files.pythonhosted.org/packages/c3/c3/6e599657fe192462f94861a09aae935b869aea8a1da07f47d6eae471397c/ruff-0.12.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7acd6045e87fac75a0b0cdedacf9ab3e1ad9d929d149785903cff9bb69ad9727", size = 12868393, upload-time = "2025-09-04T16:49:23.043Z" }, - { url = "https://files.pythonhosted.org/packages/e8/d2/9e3e40d399abc95336b1843f52fc0daaceb672d0e3c9290a28ff1a96f79d/ruff-0.12.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:abf4073688d7d6da16611f2f126be86523a8ec4343d15d276c614bda8ec44edb", size = 12036967, upload-time = "2025-09-04T16:49:26.04Z" }, - { url = "https://files.pythonhosted.org/packages/e9/03/6816b2ed08836be272e87107d905f0908be5b4a40c14bfc91043e76631b8/ruff-0.12.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:968e77094b1d7a576992ac078557d1439df678a34c6fe02fd979f973af167577", size = 12276038, upload-time = "2025-09-04T16:49:29.056Z" }, - { url = "https://files.pythonhosted.org/packages/9f/d5/707b92a61310edf358a389477eabd8af68f375c0ef858194be97ca5b6069/ruff-0.12.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42a67d16e5b1ffc6d21c5f67851e0e769517fb57a8ebad1d0781b30888aa704e", size = 11901110, upload-time = "2025-09-04T16:49:32.07Z" }, - { url = "https://files.pythonhosted.org/packages/9d/3d/f8b1038f4b9822e26ec3d5b49cf2bc313e3c1564cceb4c1a42820bf74853/ruff-0.12.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b216ec0a0674e4b1214dcc998a5088e54eaf39417327b19ffefba1c4a1e4971e", size = 13668352, upload-time = "2025-09-04T16:49:35.148Z" }, - { url = "https://files.pythonhosted.org/packages/98/0e/91421368ae6c4f3765dd41a150f760c5f725516028a6be30e58255e3c668/ruff-0.12.12-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:59f909c0fdd8f1dcdbfed0b9569b8bf428cf144bec87d9de298dcd4723f5bee8", size = 14638365, upload-time = "2025-09-04T16:49:38.892Z" }, - { url = "https://files.pythonhosted.org/packages/74/5d/88f3f06a142f58ecc8ecb0c2fe0b82343e2a2b04dcd098809f717cf74b6c/ruff-0.12.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ac93d87047e765336f0c18eacad51dad0c1c33c9df7484c40f98e1d773876f5", size = 14060812, upload-time = "2025-09-04T16:49:42.732Z" }, - { url = "https://files.pythonhosted.org/packages/13/fc/8962e7ddd2e81863d5c92400820f650b86f97ff919c59836fbc4c1a6d84c/ruff-0.12.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:01543c137fd3650d322922e8b14cc133b8ea734617c4891c5a9fccf4bfc9aa92", size = 13050208, upload-time = "2025-09-04T16:49:46.434Z" }, - { url = "https://files.pythonhosted.org/packages/53/06/8deb52d48a9a624fd37390555d9589e719eac568c020b27e96eed671f25f/ruff-0.12.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afc2fa864197634e549d87fb1e7b6feb01df0a80fd510d6489e1ce8c0b1cc45", size = 13311444, upload-time = "2025-09-04T16:49:49.931Z" }, - { url = "https://files.pythonhosted.org/packages/2a/81/de5a29af7eb8f341f8140867ffb93f82e4fde7256dadee79016ac87c2716/ruff-0.12.12-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:0c0945246f5ad776cb8925e36af2438e66188d2b57d9cf2eed2c382c58b371e5", size = 13279474, upload-time = "2025-09-04T16:49:53.465Z" }, - { url = "https://files.pythonhosted.org/packages/7f/14/d9577fdeaf791737ada1b4f5c6b59c21c3326f3f683229096cccd7674e0c/ruff-0.12.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a0fbafe8c58e37aae28b84a80ba1817f2ea552e9450156018a478bf1fa80f4e4", size = 12070204, upload-time = "2025-09-04T16:49:56.882Z" }, - { url = "https://files.pythonhosted.org/packages/77/04/a910078284b47fad54506dc0af13839c418ff704e341c176f64e1127e461/ruff-0.12.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b9c456fb2fc8e1282affa932c9e40f5ec31ec9cbb66751a316bd131273b57c23", size = 11880347, upload-time = "2025-09-04T16:49:59.729Z" }, - { url = "https://files.pythonhosted.org/packages/df/58/30185fcb0e89f05e7ea82e5817b47798f7fa7179863f9d9ba6fd4fe1b098/ruff-0.12.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f12856123b0ad0147d90b3961f5c90e7427f9acd4b40050705499c98983f489", size = 12891844, upload-time = "2025-09-04T16:50:02.591Z" }, - { url = "https://files.pythonhosted.org/packages/21/9c/28a8dacce4855e6703dcb8cdf6c1705d0b23dd01d60150786cd55aa93b16/ruff-0.12.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:26a1b5a2bf7dd2c47e3b46d077cd9c0fc3b93e6c6cc9ed750bd312ae9dc302ee", size = 13360687, upload-time = "2025-09-04T16:50:05.8Z" }, - { url = "https://files.pythonhosted.org/packages/c8/fa/05b6428a008e60f79546c943e54068316f32ec8ab5c4f73e4563934fbdc7/ruff-0.12.12-py3-none-win32.whl", hash = "sha256:173be2bfc142af07a01e3a759aba6f7791aa47acf3604f610b1c36db888df7b1", size = 12052870, upload-time = "2025-09-04T16:50:09.121Z" }, - { url = "https://files.pythonhosted.org/packages/85/60/d1e335417804df452589271818749d061b22772b87efda88354cf35cdb7a/ruff-0.12.12-py3-none-win_amd64.whl", hash = "sha256:e99620bf01884e5f38611934c09dd194eb665b0109104acae3ba6102b600fd0d", size = 13178016, upload-time = "2025-09-04T16:50:12.559Z" }, - { url = "https://files.pythonhosted.org/packages/28/7e/61c42657f6e4614a4258f1c3b0c5b93adc4d1f8575f5229d1906b483099b/ruff-0.12.12-py3-none-win_arm64.whl", hash = "sha256:2a8199cab4ce4d72d158319b63370abf60991495fb733db96cd923a34c52d093", size = 12256762, upload-time = "2025-09-04T16:50:15.737Z" }, + { url = "https://files.pythonhosted.org/packages/3a/4e/79d463a5f80654e93fa653ebfb98e0becc3f0e7cf6219c9ddedf1e197072/ruff-0.14.0-py3-none-linux_armv6l.whl", hash = "sha256:58e15bffa7054299becf4bab8a1187062c6f8cafbe9f6e39e0d5aface455d6b3", size = 12494532, upload-time = "2025-10-07T18:21:00.373Z" }, + { url = "https://files.pythonhosted.org/packages/ee/40/e2392f445ed8e02aa6105d49db4bfff01957379064c30f4811c3bf38aece/ruff-0.14.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:838d1b065f4df676b7c9957992f2304e41ead7a50a568185efd404297d5701e8", size = 13160768, upload-time = "2025-10-07T18:21:04.73Z" }, + { url = "https://files.pythonhosted.org/packages/75/da/2a656ea7c6b9bd14c7209918268dd40e1e6cea65f4bb9880eaaa43b055cd/ruff-0.14.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:703799d059ba50f745605b04638fa7e9682cc3da084b2092feee63500ff3d9b8", size = 12363376, upload-time = "2025-10-07T18:21:07.833Z" }, + { url = "https://files.pythonhosted.org/packages/42/e2/1ffef5a1875add82416ff388fcb7ea8b22a53be67a638487937aea81af27/ruff-0.14.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ba9a8925e90f861502f7d974cc60e18ca29c72bb0ee8bfeabb6ade35a3abde7", size = 12608055, upload-time = "2025-10-07T18:21:10.72Z" }, + { url = "https://files.pythonhosted.org/packages/4a/32/986725199d7cee510d9f1dfdf95bf1efc5fa9dd714d0d85c1fb1f6be3bc3/ruff-0.14.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e41f785498bd200ffc276eb9e1570c019c1d907b07cfb081092c8ad51975bbe7", size = 12318544, upload-time = "2025-10-07T18:21:13.741Z" }, + { url = "https://files.pythonhosted.org/packages/9a/ed/4969cefd53315164c94eaf4da7cfba1f267dc275b0abdd593d11c90829a3/ruff-0.14.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30a58c087aef4584c193aebf2700f0fbcfc1e77b89c7385e3139956fa90434e2", size = 14001280, upload-time = "2025-10-07T18:21:16.411Z" }, + { url = "https://files.pythonhosted.org/packages/ab/ad/96c1fc9f8854c37681c9613d825925c7f24ca1acfc62a4eb3896b50bacd2/ruff-0.14.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f8d07350bc7af0a5ce8812b7d5c1a7293cf02476752f23fdfc500d24b79b783c", size = 15027286, upload-time = "2025-10-07T18:21:19.577Z" }, + { url = "https://files.pythonhosted.org/packages/b3/00/1426978f97df4fe331074baf69615f579dc4e7c37bb4c6f57c2aad80c87f/ruff-0.14.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eec3bbbf3a7d5482b5c1f42d5fc972774d71d107d447919fca620b0be3e3b75e", size = 14451506, upload-time = "2025-10-07T18:21:22.779Z" }, + { url = "https://files.pythonhosted.org/packages/58/d5/9c1cea6e493c0cf0647674cca26b579ea9d2a213b74b5c195fbeb9678e15/ruff-0.14.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16b68e183a0e28e5c176d51004aaa40559e8f90065a10a559176713fcf435206", size = 13437384, upload-time = "2025-10-07T18:21:25.758Z" }, + { url = "https://files.pythonhosted.org/packages/29/b4/4cd6a4331e999fc05d9d77729c95503f99eae3ba1160469f2b64866964e3/ruff-0.14.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb732d17db2e945cfcbbc52af0143eda1da36ca8ae25083dd4f66f1542fdf82e", size = 13447976, upload-time = "2025-10-07T18:21:28.83Z" }, + { url = "https://files.pythonhosted.org/packages/3b/c0/ac42f546d07e4f49f62332576cb845d45c67cf5610d1851254e341d563b6/ruff-0.14.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:c958f66ab884b7873e72df38dcabee03d556a8f2ee1b8538ee1c2bbd619883dd", size = 13682850, upload-time = "2025-10-07T18:21:31.842Z" }, + { url = "https://files.pythonhosted.org/packages/5f/c4/4b0c9bcadd45b4c29fe1af9c5d1dc0ca87b4021665dfbe1c4688d407aa20/ruff-0.14.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7eb0499a2e01f6e0c285afc5bac43ab380cbfc17cd43a2e1dd10ec97d6f2c42d", size = 12449825, upload-time = "2025-10-07T18:21:35.074Z" }, + { url = "https://files.pythonhosted.org/packages/4b/a8/e2e76288e6c16540fa820d148d83e55f15e994d852485f221b9524514730/ruff-0.14.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4c63b2d99fafa05efca0ab198fd48fa6030d57e4423df3f18e03aa62518c565f", size = 12272599, upload-time = "2025-10-07T18:21:38.08Z" }, + { url = "https://files.pythonhosted.org/packages/18/14/e2815d8eff847391af632b22422b8207704222ff575dec8d044f9ab779b2/ruff-0.14.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:668fce701b7a222f3f5327f86909db2bbe99c30877c8001ff934c5413812ac02", size = 13193828, upload-time = "2025-10-07T18:21:41.216Z" }, + { url = "https://files.pythonhosted.org/packages/44/c6/61ccc2987cf0aecc588ff8f3212dea64840770e60d78f5606cd7dc34de32/ruff-0.14.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a86bf575e05cb68dcb34e4c7dfe1064d44d3f0c04bbc0491949092192b515296", size = 13628617, upload-time = "2025-10-07T18:21:44.04Z" }, + { url = "https://files.pythonhosted.org/packages/73/e6/03b882225a1b0627e75339b420883dc3c90707a8917d2284abef7a58d317/ruff-0.14.0-py3-none-win32.whl", hash = "sha256:7450a243d7125d1c032cb4b93d9625dea46c8c42b4f06c6b709baac168e10543", size = 12367872, upload-time = "2025-10-07T18:21:46.67Z" }, + { url = "https://files.pythonhosted.org/packages/41/77/56cf9cf01ea0bfcc662de72540812e5ba8e9563f33ef3d37ab2174892c47/ruff-0.14.0-py3-none-win_amd64.whl", hash = "sha256:ea95da28cd874c4d9c922b39381cbd69cb7e7b49c21b8152b014bd4f52acddc2", size = 13464628, upload-time = "2025-10-07T18:21:50.318Z" }, + { url = "https://files.pythonhosted.org/packages/c6/2a/65880dfd0e13f7f13a775998f34703674a4554906167dce02daf7865b954/ruff-0.14.0-py3-none-win_arm64.whl", hash = "sha256:f42c9495f5c13ff841b1da4cb3c2a42075409592825dada7c5885c2c844ac730", size = 12565142, upload-time = "2025-10-07T18:21:53.577Z" }, ] [[package]] @@ -6478,19 +6440,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/6f/ec0012be842b1d888d46884ac5558fd62aeae1f0ec4f7a581433d890d4b5/types_requests-2.32.4.20250809-py3-none-any.whl", hash = "sha256:f73d1832fb519ece02c85b1f09d5f0dd3108938e7d47e7f94bbfa18a6782b163", size = 20644, upload-time = "2025-08-09T03:17:09.716Z" }, ] -[[package]] -name = "types-requests-oauthlib" -version = "2.0.0.20250809" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "types-oauthlib" }, - { name = "types-requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ed/40/5eca857a2dbda0fedd69b7fd3f51cb0b6ece8d448327d29f0ae54612ec98/types_requests_oauthlib-2.0.0.20250809.tar.gz", hash = "sha256:f3b9b31e0394fe2c362f0d44bc9ef6d5c150a298d01089513cd54a51daec37a2", size = 11008, upload-time = "2025-08-09T03:17:50.705Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/38/8777f0ab409a7249777f230f6aefe0e9ba98355dc8b05fb31391fa30f312/types_requests_oauthlib-2.0.0.20250809-py3-none-any.whl", hash = "sha256:0d1af4907faf9f4a1b0f0afbc7ec488f1dd5561a2b5b6dad70f78091a1acfb76", size = 14319, upload-time = "2025-08-09T03:17:49.786Z" }, -] - [[package]] name = "types-s3transfer" version = "0.13.1" diff --git a/docker/.env.example b/docker/.env.example index d4e8ab3beb..b0e8d020ba 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -45,7 +45,7 @@ APP_WEB_URL= # Recommendation: use a dedicated domain (e.g., https://upload.example.com). # Alternatively, use http://:5001 or http://api:5001, # ensuring port 5001 is externally accessible (see docker-compose.yaml). -FILES_URL=http://api:5001 +FILES_URL= # INTERNAL_FILES_URL is used for plugin daemon communication within Docker network. # Set this to the internal Docker service URL for proper plugin file access. @@ -449,7 +449,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -580,6 +580,15 @@ ORACLE_WALLET_LOCATION=/app/api/storage/wallet ORACLE_WALLET_PASSWORD=dify ORACLE_IS_AUTONOMOUS=false +# AlibabaCloud MySQL configuration, only available when VECTOR_STORE is `alibabcloud_mysql` +ALIBABACLOUD_MYSQL_HOST=127.0.0.1 +ALIBABACLOUD_MYSQL_PORT=3306 +ALIBABACLOUD_MYSQL_USER=root +ALIBABACLOUD_MYSQL_PASSWORD=difyai123456 +ALIBABACLOUD_MYSQL_DATABASE=dify +ALIBABACLOUD_MYSQL_MAX_CONNECTION=5 +ALIBABACLOUD_MYSQL_HNSW_M=6 + # relyt configurations, only available when VECTOR_STORE is `relyt` RELYT_HOST=db RELYT_PORT=5432 @@ -655,6 +664,8 @@ LINDORM_USING_UGC=True LINDORM_QUERY_TIMEOUT=1 # OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase` +# Built-in fulltext parsers are `ngram`, `beng`, `space`, `ngram2`, `ik` +# External fulltext parsers (require plugin installation) are `japanese_ftparser`, `thai_ftparser` OCEANBASE_VECTOR_HOST=oceanbase OCEANBASE_VECTOR_PORT=2881 OCEANBASE_VECTOR_USER=root@test @@ -857,25 +868,28 @@ OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 # The sandbox service endpoint. CODE_EXECUTION_ENDPOINT=http://sandbox:8194 CODE_EXECUTION_API_KEY=dify-sandbox +CODE_EXECUTION_SSL_VERIFY=True +CODE_EXECUTION_POOL_MAX_CONNECTIONS=100 +CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20 +CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0 CODE_MAX_NUMBER=9223372036854775807 CODE_MIN_NUMBER=-9223372036854775808 CODE_MAX_DEPTH=5 CODE_MAX_PRECISION=20 -CODE_MAX_STRING_LENGTH=80000 +CODE_MAX_STRING_LENGTH=400000 CODE_MAX_STRING_ARRAY_LENGTH=30 CODE_MAX_OBJECT_ARRAY_LENGTH=30 CODE_MAX_NUMBER_ARRAY_LENGTH=1000 CODE_EXECUTION_CONNECT_TIMEOUT=10 CODE_EXECUTION_READ_TIMEOUT=60 CODE_EXECUTION_WRITE_TIMEOUT=10 -TEMPLATE_TRANSFORM_MAX_LENGTH=80000 +TEMPLATE_TRANSFORM_MAX_LENGTH=400000 # Workflow runtime configuration WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 MAX_VARIABLE_SIZE=204800 -WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_FILE_UPLOAD_LIMIT=10 # GraphEngine Worker Pool Configuration @@ -925,6 +939,16 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 HTTP_REQUEST_NODE_SSL_VERIFY=True + +# HTTP request node timeout configuration +# Maximum timeout values (in seconds) that users can set in HTTP request nodes +# - Connect timeout: Time to wait for establishing connection (default: 10s) +# - Read timeout: Time to wait for receiving response data (default: 600s, 10 minutes) +# - Write timeout: Time to wait for sending request data (default: 600s, 10 minutes) +HTTP_REQUEST_MAX_CONNECT_TIMEOUT=10 +HTTP_REQUEST_MAX_READ_TIMEOUT=600 +HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 + # Base64 encoded CA certificate data for custom certificate verification (PEM format, optional) # HTTP_REQUEST_NODE_SSL_CERT_DATA=LS0tLS1CRUdJTi... # Base64 encoded client certificate data for mutual TLS authentication (PEM format, optional) @@ -1132,6 +1156,9 @@ SSRF_DEFAULT_TIME_OUT=5 SSRF_DEFAULT_CONNECT_TIME_OUT=5 SSRF_DEFAULT_READ_TIME_OUT=5 SSRF_DEFAULT_WRITE_TIME_OUT=5 +SSRF_POOL_MAX_CONNECTIONS=100 +SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20 +SSRF_POOL_KEEPALIVE_EXPIRY=5.0 # ------------------------------ # docker env var for specifying vector db type at startup diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 93159b056f..5253f750b9 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:2.0.0-beta.2 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -31,7 +31,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:2.0.0-beta.2 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -58,7 +58,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:2.0.0-beta.2 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -76,7 +76,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:2.0.0-beta.2 + image: langgenius/dify-web:1.9.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -177,7 +177,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0b1-local + image: langgenius/dify-plugin-daemon:0.3.0-local restart: always environment: # Use the shared environment variables. diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 9e7060aad2..d350503f27 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -20,7 +20,17 @@ services: ports: - "${EXPOSE_POSTGRES_PORT:-5432}:5432" healthcheck: - test: [ 'CMD', 'pg_isready', '-h', 'db', '-U', '${PGUSER:-postgres}', '-d', '${POSTGRES_DB:-dify}' ] + test: + [ + "CMD", + "pg_isready", + "-h", + "db", + "-U", + "${PGUSER:-postgres}", + "-d", + "${POSTGRES_DB:-dify}", + ] interval: 1s timeout: 3s retries: 30 @@ -41,7 +51,11 @@ services: ports: - "${EXPOSE_REDIS_PORT:-6379}:6379" healthcheck: - test: [ 'CMD-SHELL', 'redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG' ] + test: + [ + "CMD-SHELL", + "redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG", + ] # The DifySandbox sandbox: @@ -65,13 +79,13 @@ services: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf healthcheck: - test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ] + test: ["CMD", "curl", "-f", "http://localhost:8194/health"] networks: - ssrf_proxy_network # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0b1-local + image: langgenius/dify-plugin-daemon:0.3.0-local restart: always env_file: - ./middleware.env @@ -143,7 +157,12 @@ services: volumes: - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - entrypoint: [ "sh", "-c", "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] + entrypoint: + [ + "sh", + "-c", + "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", + ] env_file: - ./middleware.env environment: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 2d6ba572e6..0df648f38f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -10,7 +10,7 @@ x-shared-env: &shared-api-worker-env SERVICE_API_URL: ${SERVICE_API_URL:-} APP_API_URL: ${APP_API_URL:-} APP_WEB_URL: ${APP_WEB_URL:-} - FILES_URL: ${FILES_URL:-http://api:5001} + FILES_URL: ${FILES_URL:-} INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-} LANG: ${LANG:-en_US.UTF-8} LC_ALL: ${LC_ALL:-en_US.UTF-8} @@ -244,6 +244,13 @@ x-shared-env: &shared-api-worker-env ORACLE_WALLET_LOCATION: ${ORACLE_WALLET_LOCATION:-/app/api/storage/wallet} ORACLE_WALLET_PASSWORD: ${ORACLE_WALLET_PASSWORD:-dify} ORACLE_IS_AUTONOMOUS: ${ORACLE_IS_AUTONOMOUS:-false} + ALIBABACLOUD_MYSQL_HOST: ${ALIBABACLOUD_MYSQL_HOST:-127.0.0.1} + ALIBABACLOUD_MYSQL_PORT: ${ALIBABACLOUD_MYSQL_PORT:-3306} + ALIBABACLOUD_MYSQL_USER: ${ALIBABACLOUD_MYSQL_USER:-root} + ALIBABACLOUD_MYSQL_PASSWORD: ${ALIBABACLOUD_MYSQL_PASSWORD:-difyai123456} + ALIBABACLOUD_MYSQL_DATABASE: ${ALIBABACLOUD_MYSQL_DATABASE:-dify} + ALIBABACLOUD_MYSQL_MAX_CONNECTION: ${ALIBABACLOUD_MYSQL_MAX_CONNECTION:-5} + ALIBABACLOUD_MYSQL_HNSW_M: ${ALIBABACLOUD_MYSQL_HNSW_M:-6} RELYT_HOST: ${RELYT_HOST:-db} RELYT_PORT: ${RELYT_PORT:-5432} RELYT_USER: ${RELYT_USER:-postgres} @@ -382,23 +389,26 @@ x-shared-env: &shared-api-worker-env OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: ${OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES:-5} CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox} + CODE_EXECUTION_SSL_VERIFY: ${CODE_EXECUTION_SSL_VERIFY:-True} + CODE_EXECUTION_POOL_MAX_CONNECTIONS: ${CODE_EXECUTION_POOL_MAX_CONNECTIONS:-100} + CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: ${CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS:-20} + CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: ${CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY:-5.0} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5} CODE_MAX_PRECISION: ${CODE_MAX_PRECISION:-20} - CODE_MAX_STRING_LENGTH: ${CODE_MAX_STRING_LENGTH:-80000} + CODE_MAX_STRING_LENGTH: ${CODE_MAX_STRING_LENGTH:-400000} CODE_MAX_STRING_ARRAY_LENGTH: ${CODE_MAX_STRING_ARRAY_LENGTH:-30} CODE_MAX_OBJECT_ARRAY_LENGTH: ${CODE_MAX_OBJECT_ARRAY_LENGTH:-30} CODE_MAX_NUMBER_ARRAY_LENGTH: ${CODE_MAX_NUMBER_ARRAY_LENGTH:-1000} CODE_EXECUTION_CONNECT_TIMEOUT: ${CODE_EXECUTION_CONNECT_TIMEOUT:-10} CODE_EXECUTION_READ_TIMEOUT: ${CODE_EXECUTION_READ_TIMEOUT:-60} CODE_EXECUTION_WRITE_TIMEOUT: ${CODE_EXECUTION_WRITE_TIMEOUT:-10} - TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-80000} + TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-400000} WORKFLOW_MAX_EXECUTION_STEPS: ${WORKFLOW_MAX_EXECUTION_STEPS:-500} WORKFLOW_MAX_EXECUTION_TIME: ${WORKFLOW_MAX_EXECUTION_TIME:-1200} WORKFLOW_CALL_MAX_DEPTH: ${WORKFLOW_CALL_MAX_DEPTH:-5} MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} - WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3} WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} GRAPH_ENGINE_MIN_WORKERS: ${GRAPH_ENGINE_MIN_WORKERS:-1} GRAPH_ENGINE_MAX_WORKERS: ${GRAPH_ENGINE_MAX_WORKERS:-10} @@ -415,6 +425,9 @@ x-shared-env: &shared-api-worker-env HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: ${HTTP_REQUEST_MAX_CONNECT_TIMEOUT:-10} + HTTP_REQUEST_MAX_READ_TIMEOUT: ${HTTP_REQUEST_MAX_READ_TIMEOUT:-600} + HTTP_REQUEST_MAX_WRITE_TIMEOUT: ${HTTP_REQUEST_MAX_WRITE_TIMEOUT:-600} RESPECT_XFORWARD_HEADERS_ENABLED: ${RESPECT_XFORWARD_HEADERS_ENABLED:-false} SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} @@ -497,6 +510,9 @@ x-shared-env: &shared-api-worker-env SSRF_DEFAULT_CONNECT_TIME_OUT: ${SSRF_DEFAULT_CONNECT_TIME_OUT:-5} SSRF_DEFAULT_READ_TIME_OUT: ${SSRF_DEFAULT_READ_TIME_OUT:-5} SSRF_DEFAULT_WRITE_TIME_OUT: ${SSRF_DEFAULT_WRITE_TIME_OUT:-5} + SSRF_POOL_MAX_CONNECTIONS: ${SSRF_POOL_MAX_CONNECTIONS:-100} + SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: ${SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS:-20} + SSRF_POOL_KEEPALIVE_EXPIRY: ${SSRF_POOL_KEEPALIVE_EXPIRY:-5.0} EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80} EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443} POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-} @@ -593,7 +609,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:2.0.0-beta.2 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -622,7 +638,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:2.0.0-beta.2 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -649,7 +665,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:2.0.0-beta.2 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -667,7 +683,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:2.0.0-beta.2 + image: langgenius/dify-web:1.9.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -768,7 +784,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.3.0b1-local + image: langgenius/dify-plugin-daemon:0.3.0-local restart: always environment: # Use the shared environment variables. diff --git a/README_AR.md b/docs/ar-SA/README.md similarity index 81% rename from README_AR.md rename to docs/ar-SA/README.md index 2451757ab5..afa494c5d3 100644 --- a/README_AR.md +++ b/docs/ar-SA/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

Dify Cloud · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

@@ -97,7 +99,7 @@
-أسهل طريقة لبدء تشغيل خادم Dify هي تشغيل ملف [docker-compose.yml](docker/docker-compose.yaml) الخاص بنا. قبل تشغيل أمر التثبيت، تأكد من تثبيت [Docker](https://docs.docker.com/get-docker/) و [Docker Compose](https://docs.docker.com/compose/install/) على جهازك: +أسهل طريقة لبدء تشغيل خادم Dify هي تشغيل ملف [docker-compose.yml](../../docker/docker-compose.yaml) الخاص بنا. قبل تشغيل أمر التثبيت، تأكد من تثبيت [Docker](https://docs.docker.com/get-docker/) و [Docker Compose](https://docs.docker.com/compose/install/) على جهازك: ```bash cd docker @@ -111,7 +113,7 @@ docker compose up -d ## الخطوات التالية -إذا كنت بحاجة إلى تخصيص الإعدادات، فيرجى الرجوع إلى التعليقات في ملف [.env.example](docker/.env.example) وتحديث القيم المقابلة في ملف `.env`. بالإضافة إلى ذلك، قد تحتاج إلى إجراء تعديلات على ملف `docker-compose.yaml` نفسه، مثل تغيير إصدارات الصور أو تعيينات المنافذ أو نقاط تحميل وحدات التخزين، بناءً على بيئة النشر ومتطلباتك الخاصة. بعد إجراء أي تغييرات، يرجى إعادة تشغيل `docker-compose up -d`. يمكنك العثور على قائمة كاملة بمتغيرات البيئة المتاحة [هنا](https://docs.dify.ai/getting-started/install-self-hosted/environments). +إذا كنت بحاجة إلى تخصيص الإعدادات، فيرجى الرجوع إلى التعليقات في ملف [.env.example](../../docker/.env.example) وتحديث القيم المقابلة في ملف `.env`. بالإضافة إلى ذلك، قد تحتاج إلى إجراء تعديلات على ملف `docker-compose.yaml` نفسه، مثل تغيير إصدارات الصور أو تعيينات المنافذ أو نقاط تحميل وحدات التخزين، بناءً على بيئة النشر ومتطلباتك الخاصة. بعد إجراء أي تغييرات، يرجى إعادة تشغيل `docker-compose up -d`. يمكنك العثور على قائمة كاملة بمتغيرات البيئة المتاحة [هنا](https://docs.dify.ai/getting-started/install-self-hosted/environments). يوجد مجتمع خاص بـ [Helm Charts](https://helm.sh/) وملفات YAML التي تسمح بتنفيذ Dify على Kubernetes للنظام من الإيجابيات العلوية. @@ -185,12 +187,4 @@ docker compose up -d ## الرخصة -هذا المستودع متاح تحت [رخصة البرنامج الحر Dify](LICENSE)، والتي تعتبر بشكل أساسي Apache 2.0 مع بعض القيود الإضافية. - -## الكشف عن الأمان - -لحماية خصوصيتك، يرجى تجنب نشر مشكلات الأمان على GitHub. بدلاً من ذلك، أرسل أسئلتك إلى وسنقدم لك إجابة أكثر تفصيلاً. - -## الرخصة - -هذا المستودع متاح تحت [رخصة البرنامج الحر Dify](LICENSE)، والتي تعتبر بشكل أساسي Apache 2.0 مع بعض القيود الإضافية. +هذا المستودع متاح تحت [رخصة البرنامج الحر Dify](../../LICENSE)، والتي تعتبر بشكل أساسي Apache 2.0 مع بعض القيود الإضافية. diff --git a/README_BN.md b/docs/bn-BD/README.md similarity index 85% rename from README_BN.md rename to docs/bn-BD/README.md index ef24dea171..318853a8de 100644 --- a/README_BN.md +++ b/docs/bn-BD/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

📌 ডিফাই ওয়ার্কফ্লো ফাইল আপলোড পরিচিতি: গুগল নোটবুক-এলএম পডকাস্ট পুনর্নির্মাণ @@ -39,18 +39,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in Deutsch - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

ডিফাই একটি ওপেন-সোর্স LLM অ্যাপ ডেভেলপমেন্ট প্ল্যাটফর্ম। এটি ইন্টুইটিভ ইন্টারফেস, এজেন্টিক AI ওয়ার্কফ্লো, RAG পাইপলাইন, এজেন্ট ক্যাপাবিলিটি, মডেল ম্যানেজমেন্ট, মনিটরিং সুবিধা এবং আরও অনেক কিছু একত্রিত করে, যা দ্রুত প্রোটোটাইপ থেকে প্রোডাকশন পর্যন্ত নিয়ে যেতে সহায়তা করে। @@ -64,7 +65,7 @@
-ডিফাই সার্ভার চালু করার সবচেয়ে সহজ উপায় [docker compose](docker/docker-compose.yaml) মাধ্যমে। নিম্নলিখিত কমান্ডগুলো ব্যবহার করে ডিফাই চালানোর আগে, নিশ্চিত করুন যে আপনার মেশিনে [Docker](https://docs.docker.com/get-docker/) এবং [Docker Compose](https://docs.docker.com/compose/install/) ইনস্টল করা আছে : +ডিফাই সার্ভার চালু করার সবচেয়ে সহজ উপায় [docker compose](../../docker/docker-compose.yaml) মাধ্যমে। নিম্নলিখিত কমান্ডগুলো ব্যবহার করে ডিফাই চালানোর আগে, নিশ্চিত করুন যে আপনার মেশিনে [Docker](https://docs.docker.com/get-docker/) এবং [Docker Compose](https://docs.docker.com/compose/install/) ইনস্টল করা আছে : ```bash cd dify @@ -128,7 +129,7 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন ## Advanced Setup -যদি আপনার কনফিগারেশনটি কাস্টমাইজ করার প্রয়োজন হয়, তাহলে অনুগ্রহ করে আমাদের [.env.example](docker/.env.example) ফাইল দেখুন এবং আপনার `.env` ফাইলে সংশ্লিষ্ট মানগুলি আপডেট করুন। এছাড়াও, আপনার নির্দিষ্ট এনভায়রনমেন্ট এবং প্রয়োজনীয়তার উপর ভিত্তি করে আপনাকে `docker-compose.yaml` ফাইলে সমন্বয় করতে হতে পারে, যেমন ইমেজ ভার্সন পরিবর্তন করা, পোর্ট ম্যাপিং করা, অথবা ভলিউম মাউন্ট করা। +যদি আপনার কনফিগারেশনটি কাস্টমাইজ করার প্রয়োজন হয়, তাহলে অনুগ্রহ করে আমাদের [.env.example](../../docker/.env.example) ফাইল দেখুন এবং আপনার `.env` ফাইলে সংশ্লিষ্ট মানগুলি আপডেট করুন। এছাড়াও, আপনার নির্দিষ্ট এনভায়রনমেন্ট এবং প্রয়োজনীয়তার উপর ভিত্তি করে আপনাকে `docker-compose.yaml` ফাইলে সমন্বয় করতে হতে পারে, যেমন ইমেজ ভার্সন পরিবর্তন করা, পোর্ট ম্যাপিং করা, অথবা ভলিউম মাউন্ট করা। যেকোনো পরিবর্তন করার পর, অনুগ্রহ করে `docker-compose up -d` পুনরায় চালান। ভেরিয়েবলের সম্পূর্ণ তালিকা [এখানে] (https://docs.dify.ai/getting-started/install-self-hosted/environments) খুঁজে পেতে পারেন। যদি আপনি একটি হাইলি এভেইলেবল সেটআপ কনফিগার করতে চান, তাহলে কমিউনিটি [Helm Charts](https://helm.sh/) এবং YAML ফাইল রয়েছে যা Dify কে Kubernetes-এ ডিপ্লয় করার প্রক্রিয়া বর্ণনা করে। @@ -175,7 +176,7 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন ## Contributing -যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)। +যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) দেখুন। একই সাথে, সোশ্যাল মিডিয়া এবং ইভেন্ট এবং কনফারেন্সে এটি শেয়ার করে Dify কে সমর্থন করুন। > আমরা ম্যান্ডারিন বা ইংরেজি ছাড়া অন্য ভাষায় Dify অনুবাদ করতে সাহায্য করার জন্য অবদানকারীদের খুঁজছি। আপনি যদি সাহায্য করতে আগ্রহী হন, তাহলে আরও তথ্যের জন্য [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) দেখুন এবং আমাদের [ডিসকর্ড কমিউনিটি সার্ভার](https://discord.gg/8Tpq4AcN9c) এর `গ্লোবাল-ইউজারস` চ্যানেলে আমাদের একটি মন্তব্য করুন। @@ -203,4 +204,4 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন ## লাইসেন্স -এই রিপোজিটরিটি [ডিফাই ওপেন সোর্স লাইসেন্স](LICENSE) এর অধিনে , যা মূলত অ্যাপাচি ২.০, তবে কিছু অতিরিক্ত বিধিনিষেধ রয়েছে। +এই রিপোজিটরিটি [ডিফাই ওপেন সোর্স লাইসেন্স](../../LICENSE) এর অধিনে , যা মূলত অ্যাপাচি ২.০, তবে কিছু অতিরিক্ত বিধিনিষেধ রয়েছে। diff --git a/CONTRIBUTING_DE.md b/docs/de-DE/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_DE.md rename to docs/de-DE/CONTRIBUTING.md index f819e80bbb..db12006b30 100644 --- a/CONTRIBUTING_DE.md +++ b/docs/de-DE/CONTRIBUTING.md @@ -6,7 +6,7 @@ Wir müssen wendig sein und schnell liefern, aber wir möchten auch sicherstelle Dieser Leitfaden ist, wie Dify selbst, in ständiger Entwicklung. Wir sind dankbar für Ihr Verständnis, falls er manchmal hinter dem eigentlichen Projekt zurückbleibt, und begrüßen jedes Feedback zur Verbesserung. -Bitte nehmen Sie sich einen Moment Zeit, um unsere [Lizenz- und Mitwirkungsvereinbarung](./LICENSE) zu lesen. Die Community hält sich außerdem an den [Verhaltenskodex](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). +Bitte nehmen Sie sich einen Moment Zeit, um unsere [Lizenz- und Mitwirkungsvereinbarung](../../LICENSE) zu lesen. Die Community hält sich außerdem an den [Verhaltenskodex](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). ## Bevor Sie loslegen diff --git a/README_DE.md b/docs/de-DE/README.md similarity index 79% rename from README_DE.md rename to docs/de-DE/README.md index a08fe63d4f..8907d914d3 100644 --- a/README_DE.md +++ b/docs/de-DE/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

📌 Einführung in Dify Workflow File Upload: Google NotebookLM Podcast nachbilden @@ -39,18 +39,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in Deutsch - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

Dify ist eine Open-Source-Plattform zur Entwicklung von LLM-Anwendungen. Ihre intuitive Benutzeroberfläche vereint agentenbasierte KI-Workflows, RAG-Pipelines, Agentenfunktionen, Modellverwaltung, Überwachungsfunktionen und mehr, sodass Sie schnell von einem Prototyp in die Produktion übergehen können. @@ -64,7 +65,7 @@ Dify ist eine Open-Source-Plattform zur Entwicklung von LLM-Anwendungen. Ihre in
-Der einfachste Weg, den Dify-Server zu starten, ist über [docker compose](docker/docker-compose.yaml). Stellen Sie vor dem Ausführen von Dify mit den folgenden Befehlen sicher, dass [Docker](https://docs.docker.com/get-docker/) und [Docker Compose](https://docs.docker.com/compose/install/) auf Ihrem System installiert sind: +Der einfachste Weg, den Dify-Server zu starten, ist über [docker compose](../../docker/docker-compose.yaml). Stellen Sie vor dem Ausführen von Dify mit den folgenden Befehlen sicher, dass [Docker](https://docs.docker.com/get-docker/) und [Docker Compose](https://docs.docker.com/compose/install/) auf Ihrem System installiert sind: ```bash cd dify @@ -127,7 +128,7 @@ Star Dify auf GitHub und lassen Sie sich sofort über neue Releases benachrichti ## Erweiterte Einstellungen -Falls Sie die Konfiguration anpassen müssen, lesen Sie bitte die Kommentare in unserer [.env.example](docker/.env.example)-Datei und aktualisieren Sie die entsprechenden Werte in Ihrer `.env`-Datei. Zusätzlich müssen Sie eventuell Anpassungen an der `docker-compose.yaml`-Datei vornehmen, wie zum Beispiel das Ändern von Image-Versionen, Portzuordnungen oder Volumen-Mounts, je nach Ihrer spezifischen Einsatzumgebung und Ihren Anforderungen. Nachdem Sie Änderungen vorgenommen haben, starten Sie `docker-compose up -d` erneut. Eine vollständige Liste der verfügbaren Umgebungsvariablen finden Sie [hier](https://docs.dify.ai/getting-started/install-self-hosted/environments). +Falls Sie die Konfiguration anpassen müssen, lesen Sie bitte die Kommentare in unserer [.env.example](../../docker/.env.example)-Datei und aktualisieren Sie die entsprechenden Werte in Ihrer `.env`-Datei. Zusätzlich müssen Sie eventuell Anpassungen an der `docker-compose.yaml`-Datei vornehmen, wie zum Beispiel das Ändern von Image-Versionen, Portzuordnungen oder Volumen-Mounts, je nach Ihrer spezifischen Einsatzumgebung und Ihren Anforderungen. Nachdem Sie Änderungen vorgenommen haben, starten Sie `docker-compose up -d` erneut. Eine vollständige Liste der verfügbaren Umgebungsvariablen finden Sie [hier](https://docs.dify.ai/getting-started/install-self-hosted/environments). Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von der Community bereitgestellte [Helm Charts](https://helm.sh/) und YAML-Dateien, die es ermöglichen, Dify auf Kubernetes bereitzustellen. @@ -173,14 +174,14 @@ Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline ## Contributing -Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_DE.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. +Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](./CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. > Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c). ## Gemeinschaft & Kontakt - [GitHub Discussion](https://github.com/langgenius/dify/discussions). Am besten geeignet für: den Austausch von Feedback und das Stellen von Fragen. -- [GitHub Issues](https://github.com/langgenius/dify/issues). Am besten für: Fehler, auf die Sie bei der Verwendung von Dify.AI stoßen, und Funktionsvorschläge. Siehe unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [GitHub Issues](https://github.com/langgenius/dify/issues). Am besten für: Fehler, auf die Sie bei der Verwendung von Dify.AI stoßen, und Funktionsvorschläge. Siehe unseren [Contribution Guide](./CONTRIBUTING.md). - [Discord](https://discord.gg/FngNHpbcY7). Am besten geeignet für: den Austausch von Bewerbungen und den Austausch mit der Community. - [X(Twitter)](https://twitter.com/dify_ai). Am besten geeignet für: den Austausch von Bewerbungen und den Austausch mit der Community. @@ -200,4 +201,4 @@ Um Ihre Privatsphäre zu schützen, vermeiden Sie es bitte, Sicherheitsprobleme ## Lizenz -Dieses Repository steht unter der [Dify Open Source License](LICENSE), die im Wesentlichen Apache 2.0 mit einigen zusätzlichen Einschränkungen ist. +Dieses Repository steht unter der [Dify Open Source License](../../LICENSE), die im Wesentlichen Apache 2.0 mit einigen zusätzlichen Einschränkungen ist. diff --git a/CONTRIBUTING_ES.md b/docs/es-ES/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_ES.md rename to docs/es-ES/CONTRIBUTING.md index e19d958c65..6cd80651c4 100644 --- a/CONTRIBUTING_ES.md +++ b/docs/es-ES/CONTRIBUTING.md @@ -6,7 +6,7 @@ Necesitamos ser ágiles y enviar rápidamente dado donde estamos, pero también Esta guía, como Dify mismo, es un trabajo en constante progreso. Agradecemos mucho tu comprensión si a veces se queda atrás del proyecto real, y damos la bienvenida a cualquier comentario para que podamos mejorar. -En términos de licencia, por favor tómate un minuto para leer nuestro breve [Acuerdo de Licencia y Colaborador](./LICENSE). La comunidad también se adhiere al [código de conducta](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). +En términos de licencia, por favor tómate un minuto para leer nuestro breve [Acuerdo de Licencia y Colaborador](../../LICENSE). La comunidad también se adhiere al [código de conducta](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). ## Antes de empezar diff --git a/README_ES.md b/docs/es-ES/README.md similarity index 79% rename from README_ES.md rename to docs/es-ES/README.md index d8fdbf54e6..b005691fea 100644 --- a/README_ES.md +++ b/docs/es-ES/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

Dify Cloud · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

# @@ -108,7 +110,7 @@ Dale estrella a Dify en GitHub y serás notificado instantáneamente de las nuev
-La forma más fácil de iniciar el servidor de Dify es ejecutar nuestro archivo [docker-compose.yml](docker/docker-compose.yaml). Antes de ejecutar el comando de instalación, asegúrate de que [Docker](https://docs.docker.com/get-docker/) y [Docker Compose](https://docs.docker.com/compose/install/) estén instalados en tu máquina: +La forma más fácil de iniciar el servidor de Dify es ejecutar nuestro archivo [docker-compose.yml](../../docker/docker-compose.yaml). Antes de ejecutar el comando de instalación, asegúrate de que [Docker](https://docs.docker.com/get-docker/) y [Docker Compose](https://docs.docker.com/compose/install/) estén instalados en tu máquina: ```bash cd docker @@ -122,7 +124,7 @@ Después de ejecutarlo, puedes acceder al panel de control de Dify en tu navegad ## Próximos pasos -Si necesita personalizar la configuración, consulte los comentarios en nuestro archivo [.env.example](docker/.env.example) y actualice los valores correspondientes en su archivo `.env`. Además, es posible que deba realizar ajustes en el propio archivo `docker-compose.yaml`, como cambiar las versiones de las imágenes, las asignaciones de puertos o los montajes de volúmenes, según su entorno de implementación y requisitos específicos. Después de realizar cualquier cambio, vuelva a ejecutar `docker-compose up -d`. Puede encontrar la lista completa de variables de entorno disponibles [aquí](https://docs.dify.ai/getting-started/install-self-hosted/environments). +Si necesita personalizar la configuración, consulte los comentarios en nuestro archivo [.env.example](../../docker/.env.example) y actualice los valores correspondientes en su archivo `.env`. Además, es posible que deba realizar ajustes en el propio archivo `docker-compose.yaml`, como cambiar las versiones de las imágenes, las asignaciones de puertos o los montajes de volúmenes, según su entorno de implementación y requisitos específicos. Después de realizar cualquier cambio, vuelva a ejecutar `docker-compose up -d`. Puede encontrar la lista completa de variables de entorno disponibles [aquí](https://docs.dify.ai/getting-started/install-self-hosted/environments). . Después de realizar los cambios, ejecuta `docker-compose up -d` nuevamente. Puedes ver la lista completa de variables de entorno [aquí](https://docs.dify.ai/getting-started/install-self-hosted/environments). @@ -170,7 +172,7 @@ Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @ ## Contribuir -Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_ES.md). +Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](./CONTRIBUTING.md). Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias. > Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c). @@ -184,7 +186,7 @@ Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en ## Comunidad y Contacto - [Discusión en GitHub](https://github.com/langgenius/dify/discussions). Lo mejor para: compartir comentarios y hacer preguntas. -- [Reporte de problemas en GitHub](https://github.com/langgenius/dify/issues). Lo mejor para: errores que encuentres usando Dify.AI y propuestas de características. Consulta nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Reporte de problemas en GitHub](https://github.com/langgenius/dify/issues). Lo mejor para: errores que encuentres usando Dify.AI y propuestas de características. Consulta nuestra [Guía de contribución](./CONTRIBUTING.md). - [Discord](https://discord.gg/FngNHpbcY7). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. - [X(Twitter)](https://twitter.com/dify_ai). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. @@ -198,12 +200,4 @@ Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En ## Licencia -Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. - -## Divulgación de Seguridad - -Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En su lugar, envía tus preguntas a security@dify.ai y te proporcionaremos una respuesta más detallada. - -## Licencia - -Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. +Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](../../LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. diff --git a/CONTRIBUTING_FR.md b/docs/fr-FR/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_FR.md rename to docs/fr-FR/CONTRIBUTING.md index 335e943fcd..74e44ca734 100644 --- a/CONTRIBUTING_FR.md +++ b/docs/fr-FR/CONTRIBUTING.md @@ -6,7 +6,7 @@ Nous devons être agiles et livrer rapidement compte tenu de notre position, mai Ce guide, comme Dify lui-même, est un travail en constante évolution. Nous apprécions grandement votre compréhension si parfois il est en retard par rapport au projet réel, et nous accueillons tout commentaire pour nous aider à nous améliorer. -En termes de licence, veuillez prendre une minute pour lire notre bref [Accord de Licence et de Contributeur](./LICENSE). La communauté adhère également au [code de conduite](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). +En termes de licence, veuillez prendre une minute pour lire notre bref [Accord de Licence et de Contributeur](../../LICENSE). La communauté adhère également au [code de conduite](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). ## Avant de vous lancer diff --git a/README_FR.md b/docs/fr-FR/README.md similarity index 79% rename from README_FR.md rename to docs/fr-FR/README.md index 7474ea50c2..3aca9a9672 100644 --- a/README_FR.md +++ b/docs/fr-FR/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

Dify Cloud · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

# @@ -108,7 +110,7 @@ Mettez une étoile à Dify sur GitHub et soyez instantanément informé des nouv
-La manière la plus simple de démarrer le serveur Dify est d'exécuter notre fichier [docker-compose.yml](docker/docker-compose.yaml). Avant d'exécuter la commande d'installation, assurez-vous que [Docker](https://docs.docker.com/get-docker/) et [Docker Compose](https://docs.docker.com/compose/install/) sont installés sur votre machine: +La manière la plus simple de démarrer le serveur Dify est d'exécuter notre fichier [docker-compose.yml](../../docker/docker-compose.yaml). Avant d'exécuter la commande d'installation, assurez-vous que [Docker](https://docs.docker.com/get-docker/) et [Docker Compose](https://docs.docker.com/compose/install/) sont installés sur votre machine: ```bash cd docker @@ -122,7 +124,7 @@ Après l'exécution, vous pouvez accéder au tableau de bord Dify dans votre nav ## Prochaines étapes -Si vous devez personnaliser la configuration, veuillez vous référer aux commentaires dans notre fichier [.env.example](docker/.env.example) et mettre à jour les valeurs correspondantes dans votre fichier `.env`. De plus, vous devrez peut-être apporter des modifications au fichier `docker-compose.yaml` lui-même, comme changer les versions d'image, les mappages de ports ou les montages de volumes, en fonction de votre environnement de déploiement et de vos exigences spécifiques. Après avoir effectué des modifications, veuillez réexécuter `docker-compose up -d`. Vous pouvez trouver la liste complète des variables d'environnement disponibles [ici](https://docs.dify.ai/getting-started/install-self-hosted/environments). +Si vous devez personnaliser la configuration, veuillez vous référer aux commentaires dans notre fichier [.env.example](../../docker/.env.example) et mettre à jour les valeurs correspondantes dans votre fichier `.env`. De plus, vous devrez peut-être apporter des modifications au fichier `docker-compose.yaml` lui-même, comme changer les versions d'image, les mappages de ports ou les montages de volumes, en fonction de votre environnement de déploiement et de vos exigences spécifiques. Après avoir effectué des modifications, veuillez réexécuter `docker-compose up -d`. Vous pouvez trouver la liste complète des variables d'environnement disponibles [ici](https://docs.dify.ai/getting-started/install-self-hosted/environments). Si vous souhaitez configurer une configuration haute disponibilité, la communauté fournit des [Helm Charts](https://helm.sh/) et des fichiers YAML, à travers lesquels vous pouvez déployer Dify sur Kubernetes. @@ -168,7 +170,7 @@ Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart ## Contribuer -Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_FR.md). +Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](./CONTRIBUTING.md). Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences. > Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c). @@ -182,7 +184,7 @@ Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur le ## Communauté & Contact - [Discussion GitHub](https://github.com/langgenius/dify/discussions). Meilleur pour: partager des commentaires et poser des questions. -- [Problèmes GitHub](https://github.com/langgenius/dify/issues). Meilleur pour: les bogues que vous rencontrez en utilisant Dify.AI et les propositions de fonctionnalités. Consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Problèmes GitHub](https://github.com/langgenius/dify/issues). Meilleur pour: les bogues que vous rencontrez en utilisant Dify.AI et les propositions de fonctionnalités. Consultez notre [Guide de contribution](./CONTRIBUTING.md). - [Discord](https://discord.gg/FngNHpbcY7). Meilleur pour: partager vos applications et passer du temps avec la communauté. - [X(Twitter)](https://twitter.com/dify_ai). Meilleur pour: partager vos applications et passer du temps avec la communauté. @@ -196,12 +198,4 @@ Pour protéger votre vie privée, veuillez éviter de publier des problèmes de ## Licence -Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. - -## Divulgation de sécurité - -Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Au lieu de cela, envoyez vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée. - -## Licence - -Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. +Ce référentiel est disponible sous la [Licence open source Dify](../../LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. diff --git a/CONTRIBUTING_JA.md b/docs/ja-JP/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_JA.md rename to docs/ja-JP/CONTRIBUTING.md index 2d0d79fc16..4ee7d8c963 100644 --- a/CONTRIBUTING_JA.md +++ b/docs/ja-JP/CONTRIBUTING.md @@ -6,7 +6,7 @@ Difyに貢献しようとお考えですか?素晴らしいですね。私た このガイドは、Dify自体と同様に、常に進化し続けています。実際のプロジェクトの進行状況と多少のずれが生じる場合もございますが、ご理解いただけますと幸いです。改善のためのフィードバックも歓迎いたします。 -ライセンスについては、[ライセンスと貢献者同意書](./LICENSE)をご一読ください。また、コミュニティは[行動規範](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)に従っています。 +ライセンスについては、[ライセンスと貢献者同意書](../../LICENSE)をご一読ください。また、コミュニティは[行動規範](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)に従っています。 ## 始める前に diff --git a/README_JA.md b/docs/ja-JP/README.md similarity index 79% rename from README_JA.md rename to docs/ja-JP/README.md index a782849f6e..66831285d6 100644 --- a/README_JA.md +++ b/docs/ja-JP/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

Dify Cloud · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

# @@ -109,7 +111,7 @@ GitHub上でDifyにスターを付けることで、Difyに関する新しいニ
-Difyサーバーを起動する最も簡単な方法は、[docker-compose.yml](docker/docker-compose.yaml)ファイルを実行することです。インストールコマンドを実行する前に、マシンに[Docker](https://docs.docker.com/get-docker/)と[Docker Compose](https://docs.docker.com/compose/install/)がインストールされていることを確認してください。 +Difyサーバーを起動する最も簡単な方法は、[docker-compose.yml](../../docker/docker-compose.yaml)ファイルを実行することです。インストールコマンドを実行する前に、マシンに[Docker](https://docs.docker.com/get-docker/)と[Docker Compose](https://docs.docker.com/compose/install/)がインストールされていることを確認してください。 ```bash cd docker @@ -123,7 +125,7 @@ docker compose up -d ## 次のステップ -設定をカスタマイズする必要がある場合は、[.env.example](docker/.env.example) ファイルのコメントを参照し、`.env` ファイルの対応する値を更新してください。さらに、デプロイ環境や要件に応じて、`docker-compose.yaml` ファイル自体を調整する必要がある場合があります。たとえば、イメージのバージョン、ポートのマッピング、ボリュームのマウントなどを変更します。変更を加えた後は、`docker-compose up -d` を再実行してください。利用可能な環境変数の全一覧は、[こちら](https://docs.dify.ai/getting-started/install-self-hosted/environments)で確認できます。 +設定をカスタマイズする必要がある場合は、[.env.example](../../docker/.env.example) ファイルのコメントを参照し、`.env` ファイルの対応する値を更新してください。さらに、デプロイ環境や要件に応じて、`docker-compose.yaml` ファイル自体を調整する必要がある場合があります。たとえば、イメージのバージョン、ポートのマッピング、ボリュームのマウントなどを変更します。変更を加えた後は、`docker-compose up -d` を再実行してください。利用可能な環境変数の全一覧は、[こちら](https://docs.dify.ai/getting-started/install-self-hosted/environments)で確認できます。 高可用性設定を設定する必要がある場合、コミュニティは[Helm Charts](https://helm.sh/)とYAMLファイルにより、DifyをKubernetesにデプロイすることができます。 @@ -169,7 +171,7 @@ docker compose up -d ## 貢献 -コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_JA.md)を参照してください。 +コードに貢献したい方は、[Contribution Guide](./CONTRIBUTING.md)を参照してください。 同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。 > Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。 @@ -183,10 +185,10 @@ docker compose up -d ## コミュニティ & お問い合わせ - [GitHub Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。 -- [GitHub Issues](https://github.com/langgenius/dify/issues). 主に: Dify.AIを使用する際に発生するエラーや問題については、[貢献ガイド](CONTRIBUTING_JA.md)を参照してください +- [GitHub Issues](https://github.com/langgenius/dify/issues). 主に: Dify.AIを使用する際に発生するエラーや問題については、[貢献ガイド](./CONTRIBUTING.md)を参照してください - [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。 - [X(Twitter)](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。 ## ライセンス -このリポジトリは、Dify Open Source License にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用可能です。 +このリポジトリは、Dify Open Source License にいくつかの追加制限を加えた[Difyオープンソースライセンス](../../LICENSE)の下で利用可能です。 diff --git a/CONTRIBUTING_KR.md b/docs/ko-KR/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_KR.md rename to docs/ko-KR/CONTRIBUTING.md index 14b1c9a9ca..9c171c3561 100644 --- a/CONTRIBUTING_KR.md +++ b/docs/ko-KR/CONTRIBUTING.md @@ -6,7 +6,7 @@ Dify에 기여하려고 하시는군요 - 정말 멋집니다, 당신이 무엇 이 가이드는 Dify 자체와 마찬가지로 끊임없이 진행 중인 작업입니다. 때로는 실제 프로젝트보다 뒤처질 수 있다는 점을 이해해 주시면 감사하겠으며, 개선을 위한 피드백은 언제든지 환영합니다. -라이센스 측면에서, 간략한 [라이센스 및 기여자 동의서](./LICENSE)를 읽어보는 시간을 가져주세요. 커뮤니티는 또한 [행동 강령](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)을 준수합니다. +라이센스 측면에서, 간략한 [라이센스 및 기여자 동의서](../../LICENSE)를 읽어보는 시간을 가져주세요. 커뮤니티는 또한 [행동 강령](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)을 준수합니다. ## 시작하기 전에 diff --git a/README_KR.md b/docs/ko-KR/README.md similarity index 79% rename from README_KR.md rename to docs/ko-KR/README.md index ec28cc0f61..ec67bc90ed 100644 --- a/README_KR.md +++ b/docs/ko-KR/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

Dify 클라우드 · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

Dify는 오픈 소스 LLM 앱 개발 플랫폼입니다. 직관적인 인터페이스를 통해 AI 워크플로우, RAG 파이프라인, 에이전트 기능, 모델 관리, 관찰 기능 등을 결합하여 프로토타입에서 프로덕션까지 빠르게 전환할 수 있습니다. 주요 기능 목록은 다음과 같습니다:

@@ -102,7 +104,7 @@ GitHub에서 Dify에 별표를 찍어 새로운 릴리스를 즉시 알림 받
-Dify 서버를 시작하는 가장 쉬운 방법은 [docker-compose.yml](docker/docker-compose.yaml) 파일을 실행하는 것입니다. 설치 명령을 실행하기 전에 [Docker](https://docs.docker.com/get-docker/) 및 [Docker Compose](https://docs.docker.com/compose/install/)가 머신에 설치되어 있는지 확인하세요. +Dify 서버를 시작하는 가장 쉬운 방법은 [docker-compose.yml](../../docker/docker-compose.yaml) 파일을 실행하는 것입니다. 설치 명령을 실행하기 전에 [Docker](https://docs.docker.com/get-docker/) 및 [Docker Compose](https://docs.docker.com/compose/install/)가 머신에 설치되어 있는지 확인하세요. ```bash cd docker @@ -116,7 +118,7 @@ docker compose up -d ## 다음 단계 -구성을 사용자 정의해야 하는 경우 [.env.example](docker/.env.example) 파일의 주석을 참조하고 `.env` 파일에서 해당 값을 업데이트하십시오. 또한 특정 배포 환경 및 요구 사항에 따라 `docker-compose.yaml` 파일 자체를 조정해야 할 수도 있습니다. 예를 들어 이미지 버전, 포트 매핑 또는 볼륨 마운트를 변경합니다. 변경 한 후 `docker-compose up -d`를 다시 실행하십시오. 사용 가능한 환경 변수의 전체 목록은 [여기](https://docs.dify.ai/getting-started/install-self-hosted/environments)에서 찾을 수 있습니다. +구성을 사용자 정의해야 하는 경우 [.env.example](../../docker/.env.example) 파일의 주석을 참조하고 `.env` 파일에서 해당 값을 업데이트하십시오. 또한 특정 배포 환경 및 요구 사항에 따라 `docker-compose.yaml` 파일 자체를 조정해야 할 수도 있습니다. 예를 들어 이미지 버전, 포트 매핑 또는 볼륨 마운트를 변경합니다. 변경 한 후 `docker-compose up -d`를 다시 실행하십시오. 사용 가능한 환경 변수의 전체 목록은 [여기](https://docs.dify.ai/getting-started/install-self-hosted/environments)에서 찾을 수 있습니다. Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했다는 커뮤니티가 제공하는 [Helm Charts](https://helm.sh/)와 YAML 파일이 존재합니다. @@ -162,7 +164,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ## 기여 -코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_KR.md)를 참조하세요. +코드에 기여하고 싶은 분들은 [기여 가이드](./CONTRIBUTING.md)를 참조하세요. 동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다. > 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요. @@ -176,7 +178,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ## 커뮤니티 & 연락처 - [GitHub 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다. -- [GitHub 이슈](https://github.com/langgenius/dify/issues). Dify.AI 사용 중 발견한 버그와 기능 제안에 적합합니다. [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. +- [GitHub 이슈](https://github.com/langgenius/dify/issues). Dify.AI 사용 중 발견한 버그와 기능 제안에 적합합니다. [기여 가이드](./CONTRIBUTING.md)를 참조하세요. - [디스코드](https://discord.gg/FngNHpbcY7). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. - [트위터](https://twitter.com/dify_ai). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. @@ -190,4 +192,4 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ## 라이선스 -이 저장소는 기본적으로 몇 가지 추가 제한 사항이 있는 Apache 2.0인 [Dify 오픈 소스 라이선스](LICENSE)에 따라 사용할 수 있습니다. +이 저장소는 기본적으로 몇 가지 추가 제한 사항이 있는 Apache 2.0인 [Dify 오픈 소스 라이선스](../../LICENSE)에 따라 사용할 수 있습니다. diff --git a/CONTRIBUTING_PT.md b/docs/pt-BR/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_PT.md rename to docs/pt-BR/CONTRIBUTING.md index aeabcad51f..737b2ddce2 100644 --- a/CONTRIBUTING_PT.md +++ b/docs/pt-BR/CONTRIBUTING.md @@ -6,7 +6,7 @@ Precisamos ser ágeis e entregar rapidamente considerando onde estamos, mas tamb Este guia, como o próprio Dify, é um trabalho em constante evolução. Agradecemos muito a sua compreensão se às vezes ele ficar atrasado em relação ao projeto real, e damos as boas-vindas a qualquer feedback para que possamos melhorar. -Em termos de licenciamento, por favor, dedique um minuto para ler nosso breve [Acordo de Licença e Contribuidor](./LICENSE). A comunidade também adere ao [código de conduta](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). +Em termos de licenciamento, por favor, dedique um minuto para ler nosso breve [Acordo de Licença e Contribuidor](../../LICENSE). A comunidade também adere ao [código de conduta](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). ## Antes de começar diff --git a/README_PT.md b/docs/pt-BR/README.md similarity index 78% rename from README_PT.md rename to docs/pt-BR/README.md index da8f354a49..78383a3c76 100644 --- a/README_PT.md +++ b/docs/pt-BR/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

📌 Introduzindo o Dify Workflow com Upload de Arquivo: Recrie o Podcast Google NotebookLM @@ -39,18 +39,20 @@

- README em Inglês - 简体中文版自述文件 - 日本語のREADME - README em Espanhol - README em Francês - README tlhIngan Hol - README em Coreano - README em Árabe - README em Turco - README em Vietnamita - README em Português - BR - README in বাংলা + README em Inglês + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README em Espanhol + README em Francês + README tlhIngan Hol + README em Coreano + README em Árabe + README em Turco + README em Vietnamita + README em Português - BR + README in Deutsch + README in বাংলা

Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. Sua interface intuitiva combina workflow de IA, pipeline RAG, capacidades de agente, gerenciamento de modelos, recursos de observabilidade e muito mais, permitindo que você vá rapidamente do protótipo à produção. Aqui está uma lista das principais funcionalidades: @@ -108,7 +110,7 @@ Dê uma estrela no Dify no GitHub e seja notificado imediatamente sobre novos la
-A maneira mais fácil de iniciar o servidor Dify é executar nosso arquivo [docker-compose.yml](docker/docker-compose.yaml). Antes de rodar o comando de instalação, certifique-se de que o [Docker](https://docs.docker.com/get-docker/) e o [Docker Compose](https://docs.docker.com/compose/install/) estão instalados na sua máquina: +A maneira mais fácil de iniciar o servidor Dify é executar nosso arquivo [docker-compose.yml](../../docker/docker-compose.yaml). Antes de rodar o comando de instalação, certifique-se de que o [Docker](https://docs.docker.com/get-docker/) e o [Docker Compose](https://docs.docker.com/compose/install/) estão instalados na sua máquina: ```bash cd docker @@ -122,7 +124,7 @@ Após a execução, você pode acessar o painel do Dify no navegador em [http:// ## Próximos passos -Se precisar personalizar a configuração, consulte os comentários no nosso arquivo [.env.example](docker/.env.example) e atualize os valores correspondentes no seu arquivo `.env`. Além disso, talvez seja necessário fazer ajustes no próprio arquivo `docker-compose.yaml`, como alterar versões de imagem, mapeamentos de portas ou montagens de volumes, com base no seu ambiente de implantação específico e nas suas necessidades. Após fazer quaisquer alterações, execute novamente `docker-compose up -d`. Você pode encontrar a lista completa de variáveis de ambiente disponíveis [aqui](https://docs.dify.ai/getting-started/install-self-hosted/environments). +Se precisar personalizar a configuração, consulte os comentários no nosso arquivo [.env.example](../../docker/.env.example) e atualize os valores correspondentes no seu arquivo `.env`. Além disso, talvez seja necessário fazer ajustes no próprio arquivo `docker-compose.yaml`, como alterar versões de imagem, mapeamentos de portas ou montagens de volumes, com base no seu ambiente de implantação específico e nas suas necessidades. Após fazer quaisquer alterações, execute novamente `docker-compose up -d`. Você pode encontrar a lista completa de variáveis de ambiente disponíveis [aqui](https://docs.dify.ai/getting-started/install-self-hosted/environments). Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts](https://helm.sh/) e arquivos YAML contribuídos pela comunidade que permitem a implantação do Dify no Kubernetes. @@ -168,7 +170,7 @@ Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by ## Contribuindo -Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_PT.md). +Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](./CONTRIBUTING.md). Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências. > Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c). @@ -182,7 +184,7 @@ Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em ## Comunidade e contato - [Discussões no GitHub](https://github.com/langgenius/dify/discussions). Melhor para: compartilhar feedback e fazer perguntas. -- [Problemas no GitHub](https://github.com/langgenius/dify/issues). Melhor para: relatar bugs encontrados no Dify.AI e propor novos recursos. Veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Problemas no GitHub](https://github.com/langgenius/dify/issues). Melhor para: relatar bugs encontrados no Dify.AI e propor novos recursos. Veja nosso [Guia de Contribuição](./CONTRIBUTING.md). - [Discord](https://discord.gg/FngNHpbcY7). Melhor para: compartilhar suas aplicações e interagir com a comunidade. - [X(Twitter)](https://twitter.com/dify_ai). Melhor para: compartilhar suas aplicações e interagir com a comunidade. @@ -196,4 +198,4 @@ Para proteger sua privacidade, evite postar problemas de segurança no GitHub. E ## Licença -Este repositório está disponível sob a [Licença de Código Aberto Dify](LICENSE), que é essencialmente Apache 2.0 com algumas restrições adicionais. +Este repositório está disponível sob a [Licença de Código Aberto Dify](../../LICENSE), que é essencialmente Apache 2.0 com algumas restrições adicionais. diff --git a/README_SI.md b/docs/sl-SI/README.md similarity index 83% rename from README_SI.md rename to docs/sl-SI/README.md index c20dc3484f..65aedb7703 100644 --- a/README_SI.md +++ b/docs/sl-SI/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

📌 Predstavljamo nalaganje datotek Dify Workflow: znova ustvarite Google NotebookLM Podcast @@ -36,18 +36,20 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README Slovenščina - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README Slovenščina + README in Deutsch + README in বাংলা

Dify je odprtokodna platforma za razvoj aplikacij LLM. Njegov intuitivni vmesnik združuje agentski potek dela z umetno inteligenco, cevovod RAG, zmogljivosti agentov, upravljanje modelov, funkcije opazovanja in več, kar vam omogoča hiter prehod od prototipa do proizvodnje. @@ -169,7 +171,7 @@ Z enim klikom namestite Dify v AKS z uporabo [Azure Devops Pipeline Helm Chart b ## Prispevam -Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah. +Za tiste, ki bi radi prispevali kodo, si oglejte naš [vodnik za prispevke](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah. > Iščemo sodelavce za pomoč pri prevajanju Difyja v jezike, ki niso mandarinščina ali angleščina. Če želite pomagati, si oglejte i18n README za več informacij in nam pustite komentar v global-userskanalu našega strežnika skupnosti Discord . @@ -196,4 +198,4 @@ Zaradi zaščite vaše zasebnosti se izogibajte objavljanju varnostnih vprašanj ## Licenca -To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami. +To skladišče je na voljo pod [odprtokodno licenco Dify](../../LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami. diff --git a/README_KL.md b/docs/tlh/README.md similarity index 79% rename from README_KL.md rename to docs/tlh/README.md index 93da9a6140..b1e3016efd 100644 --- a/README_KL.md +++ b/docs/tlh/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

Dify Cloud · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

# @@ -108,7 +110,7 @@ Star Dify on GitHub and be instantly notified of new releases.
-The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: +The easiest way to start the Dify server is to run our [docker-compose.yml](../../docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: ```bash cd docker @@ -122,7 +124,7 @@ After running, you can access the Dify dashboard in your browser at [http://loca ## Next steps -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). +If you need to customize the configuration, please refer to the comments in our [.env.example](../../docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. @@ -181,10 +183,7 @@ At the same time, please consider supporting Dify by sharing it on social media ## Community & Contact -- \[GitHub Discussion\](https://github.com/langgenius/dify/discussions - -). Best for: sharing feedback and asking questions. - +- [GitHub Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. - [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). - [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. - [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. @@ -199,4 +198,4 @@ To protect your privacy, please avoid posting security issues on GitHub. Instead ## License -This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. +This repository is available under the [Dify Open Source License](../../LICENSE), which is essentially Apache 2.0 with a few additional restrictions. diff --git a/CONTRIBUTING_TR.md b/docs/tr-TR/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_TR.md rename to docs/tr-TR/CONTRIBUTING.md index d016802a53..59227d31a9 100644 --- a/CONTRIBUTING_TR.md +++ b/docs/tr-TR/CONTRIBUTING.md @@ -6,7 +6,7 @@ Bulunduğumuz noktada çevik olmamız ve hızlı hareket etmemiz gerekiyor, anca Bu rehber, Dify'ın kendisi gibi, sürekli gelişen bir çalışmadır. Bazen gerçek projenin gerisinde kalırsa anlayışınız için çok minnettarız ve gelişmemize yardımcı olacak her türlü geri bildirimi memnuniyetle karşılıyoruz. -Lisanslama konusunda, lütfen kısa [Lisans ve Katkıda Bulunan Anlaşmamızı](./LICENSE) okumak için bir dakikanızı ayırın. Topluluk ayrıca [davranış kurallarına](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md) da uyar. +Lisanslama konusunda, lütfen kısa [Lisans ve Katkıda Bulunan Anlaşmamızı](../../LICENSE) okumak için bir dakikanızı ayırın. Topluluk ayrıca [davranış kurallarına](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md) da uyar. ## Başlamadan Önce diff --git a/README_TR.md b/docs/tr-TR/README.md similarity index 79% rename from README_TR.md rename to docs/tr-TR/README.md index 21df0d1605..a044da1f4e 100644 --- a/README_TR.md +++ b/docs/tr-TR/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

Dify Bulut · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

Dify, açık kaynaklı bir LLM uygulama geliştirme platformudur. Sezgisel arayüzü, AI iş akışı, RAG pipeline'ı, ajan yetenekleri, model yönetimi, gözlemlenebilirlik özellikleri ve daha fazlasını birleştirerek, prototipten üretime hızlıca geçmenizi sağlar. İşte temel özelliklerin bir listesi: @@ -102,7 +104,7 @@ GitHub'da Dify'a yıldız verin ve yeni sürümlerden anında haberdar olun. > - RAM >= 4GB
-Dify sunucusunu başlatmanın en kolay yolu, [docker-compose.yml](docker/docker-compose.yaml) dosyamızı çalıştırmaktır. Kurulum komutunu çalıştırmadan önce, makinenizde [Docker](https://docs.docker.com/get-docker/) ve [Docker Compose](https://docs.docker.com/compose/install/)'un kurulu olduğundan emin olun: +Dify sunucusunu başlatmanın en kolay yolu, [docker-compose.yml](../../docker/docker-compose.yaml) dosyamızı çalıştırmaktır. Kurulum komutunu çalıştırmadan önce, makinenizde [Docker](https://docs.docker.com/get-docker/) ve [Docker Compose](https://docs.docker.com/compose/install/)'un kurulu olduğundan emin olun: ```bash cd docker @@ -116,7 +118,7 @@ docker compose up -d ## Sonraki adımlar -Yapılandırmayı özelleştirmeniz gerekiyorsa, lütfen [.env.example](docker/.env.example) dosyamızdaki yorumlara bakın ve `.env` dosyanızdaki ilgili değerleri güncelleyin. Ayrıca, spesifik dağıtım ortamınıza ve gereksinimlerinize bağlı olarak `docker-compose.yaml` dosyasının kendisinde de, imaj sürümlerini, port eşlemelerini veya hacim bağlantılarını değiştirmek gibi ayarlamalar yapmanız gerekebilir. Herhangi bir değişiklik yaptıktan sonra, lütfen `docker-compose up -d` komutunu tekrar çalıştırın. Kullanılabilir tüm ortam değişkenlerinin tam listesini [burada](https://docs.dify.ai/getting-started/install-self-hosted/environments) bulabilirsiniz. +Yapılandırmayı özelleştirmeniz gerekiyorsa, lütfen [.env.example](../../docker/.env.example) dosyamızdaki yorumlara bakın ve `.env` dosyanızdaki ilgili değerleri güncelleyin. Ayrıca, spesifik dağıtım ortamınıza ve gereksinimlerinize bağlı olarak `docker-compose.yaml` dosyasının kendisinde de, imaj sürümlerini, port eşlemelerini veya hacim bağlantılarını değiştirmek gibi ayarlamalar yapmanız gerekebilir. Herhangi bir değişiklik yaptıktan sonra, lütfen `docker-compose up -d` komutunu tekrar çalıştırın. Kullanılabilir tüm ortam değişkenlerinin tam listesini [burada](https://docs.dify.ai/getting-started/install-self-hosted/environments) bulabilirsiniz. Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify'ın Kubernetes üzerine dağıtılmasına olanak tanıyan topluluk katkılı [Helm Charts](https://helm.sh/) ve YAML dosyaları mevcuttur. @@ -161,7 +163,7 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter ## Katkıda Bulunma -Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TR.md) bakabilirsiniz. +Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](./CONTRIBUTING.md) bakabilirsiniz. Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün. > Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın. @@ -175,7 +177,7 @@ Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda p ## Topluluk & iletişim - [GitHub Tartışmaları](https://github.com/langgenius/dify/discussions). En uygun: geri bildirim paylaşmak ve soru sormak için. -- [GitHub Sorunları](https://github.com/langgenius/dify/issues). En uygun: Dify.AI kullanırken karşılaştığınız hatalar ve özellik önerileri için. [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakın. +- [GitHub Sorunları](https://github.com/langgenius/dify/issues). En uygun: Dify.AI kullanırken karşılaştığınız hatalar ve özellik önerileri için. [Katkı Kılavuzumuza](./CONTRIBUTING.md) bakın. - [Discord](https://discord.gg/FngNHpbcY7). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. - [X(Twitter)](https://twitter.com/dify_ai). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. @@ -189,4 +191,4 @@ Gizliliğinizi korumak için, lütfen güvenlik sorunlarını GitHub'da paylaşm ## Lisans -Bu depo, temel olarak Apache 2.0 lisansı ve birkaç ek kısıtlama içeren [Dify Açık Kaynak Lisansı](LICENSE) altında kullanıma sunulmuştur. +Bu depo, temel olarak Apache 2.0 lisansı ve birkaç ek kısıtlama içeren [Dify Açık Kaynak Lisansı](../../LICENSE) altında kullanıma sunulmuştur. diff --git a/CONTRIBUTING_VI.md b/docs/vi-VN/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_VI.md rename to docs/vi-VN/CONTRIBUTING.md index 2ad431296a..fa1d875f83 100644 --- a/CONTRIBUTING_VI.md +++ b/docs/vi-VN/CONTRIBUTING.md @@ -6,7 +6,7 @@ Chúng tôi cần phải nhanh nhẹn và triển khai nhanh chóng, nhưng cũn Hướng dẫn này, giống như Dify, đang được phát triển liên tục. Chúng tôi rất cảm kích sự thông cảm của bạn nếu đôi khi nó chưa theo kịp dự án thực tế, và hoan nghênh mọi phản hồi để cải thiện. -Về giấy phép, vui lòng dành chút thời gian đọc [Thỏa thuận Cấp phép và Người đóng góp](./LICENSE) ngắn gọn của chúng tôi. Cộng đồng cũng tuân theo [quy tắc ứng xử](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). +Về giấy phép, vui lòng dành chút thời gian đọc [Thỏa thuận Cấp phép và Người đóng góp](../../LICENSE) ngắn gọn của chúng tôi. Cộng đồng cũng tuân theo [quy tắc ứng xử](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). ## Trước khi bắt đầu diff --git a/README_VI.md b/docs/vi-VN/README.md similarity index 80% rename from README_VI.md rename to docs/vi-VN/README.md index 6d5305fb75..847641da12 100644 --- a/README_VI.md +++ b/docs/vi-VN/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

Dify Cloud · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা

Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Giao diện trực quan kết hợp quy trình làm việc AI, mô hình RAG, khả năng tác nhân, quản lý mô hình, tính năng quan sát và hơn thế nữa, cho phép bạn nhanh chóng chuyển từ nguyên mẫu sang sản phẩm. Đây là danh sách các tính năng cốt lõi: @@ -103,7 +105,7 @@ Yêu thích Dify trên GitHub và được thông báo ngay lập tức về cá
-Cách dễ nhất để khởi động máy chủ Dify là chạy tệp [docker-compose.yml](docker/docker-compose.yaml) của chúng tôi. Trước khi chạy lệnh cài đặt, hãy đảm bảo rằng [Docker](https://docs.docker.com/get-docker/) và [Docker Compose](https://docs.docker.com/compose/install/) đã được cài đặt trên máy của bạn: +Cách dễ nhất để khởi động máy chủ Dify là chạy tệp [docker-compose.yml](../../docker/docker-compose.yaml) của chúng tôi. Trước khi chạy lệnh cài đặt, hãy đảm bảo rằng [Docker](https://docs.docker.com/get-docker/) và [Docker Compose](https://docs.docker.com/compose/install/) đã được cài đặt trên máy của bạn: ```bash cd docker @@ -117,7 +119,7 @@ Sau khi chạy, bạn có thể truy cập bảng điều khiển Dify trong tr ## Các bước tiếp theo -Nếu bạn cần tùy chỉnh cấu hình, vui lòng tham khảo các nhận xét trong tệp [.env.example](docker/.env.example) của chúng tôi và cập nhật các giá trị tương ứng trong tệp `.env` của bạn. Ngoài ra, bạn có thể cần điều chỉnh tệp `docker-compose.yaml`, chẳng hạn như thay đổi phiên bản hình ảnh, ánh xạ cổng hoặc gắn kết khối lượng, dựa trên môi trường triển khai cụ thể và yêu cầu của bạn. Sau khi thực hiện bất kỳ thay đổi nào, vui lòng chạy lại `docker-compose up -d`. Bạn có thể tìm thấy danh sách đầy đủ các biến môi trường có sẵn [tại đây](https://docs.dify.ai/getting-started/install-self-hosted/environments). +Nếu bạn cần tùy chỉnh cấu hình, vui lòng tham khảo các nhận xét trong tệp [.env.example](../../docker/.env.example) của chúng tôi và cập nhật các giá trị tương ứng trong tệp `.env` của bạn. Ngoài ra, bạn có thể cần điều chỉnh tệp `docker-compose.yaml`, chẳng hạn như thay đổi phiên bản hình ảnh, ánh xạ cổng hoặc gắn kết khối lượng, dựa trên môi trường triển khai cụ thể và yêu cầu của bạn. Sau khi thực hiện bất kỳ thay đổi nào, vui lòng chạy lại `docker-compose up -d`. Bạn có thể tìm thấy danh sách đầy đủ các biến môi trường có sẵn [tại đây](https://docs.dify.ai/getting-started/install-self-hosted/environments). Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có các [Helm Charts](https://helm.sh/) và tệp YAML do cộng đồng đóng góp cho phép Dify được triển khai trên Kubernetes. @@ -162,7 +164,7 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De ## Đóng góp -Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_VI.md) của chúng tôi. +Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](./CONTRIBUTING.md) của chúng tôi. Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị. > Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi. @@ -176,7 +178,7 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De ## Cộng đồng & liên hệ - [Thảo luận GitHub](https://github.com/langgenius/dify/discussions). Tốt nhất cho: chia sẻ phản hồi và đặt câu hỏi. -- [Vấn đề GitHub](https://github.com/langgenius/dify/issues). Tốt nhất cho: lỗi bạn gặp phải khi sử dụng Dify.AI và đề xuất tính năng. Xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. +- [Vấn đề GitHub](https://github.com/langgenius/dify/issues). Tốt nhất cho: lỗi bạn gặp phải khi sử dụng Dify.AI và đề xuất tính năng. Xem [Hướng dẫn Đóng góp](./CONTRIBUTING.md) của chúng tôi. - [Discord](https://discord.gg/FngNHpbcY7). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng. - [X(Twitter)](https://twitter.com/dify_ai). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng. @@ -190,4 +192,4 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De ## Giấy phép -Kho lưu trữ này có sẵn theo [Giấy phép Mã nguồn Mở Dify](LICENSE), về cơ bản là Apache 2.0 với một vài hạn chế bổ sung. +Kho lưu trữ này có sẵn theo [Giấy phép Mã nguồn Mở Dify](../../LICENSE), về cơ bản là Apache 2.0 với một vài hạn chế bổ sung. diff --git a/CONTRIBUTING_CN.md b/docs/zh-CN/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_CN.md rename to docs/zh-CN/CONTRIBUTING.md index c278c8fd7a..5b71467804 100644 --- a/CONTRIBUTING_CN.md +++ b/docs/zh-CN/CONTRIBUTING.md @@ -6,7 +6,7 @@ 本指南和 Dify 一样在不断完善中。如果有任何滞后于项目实际情况的地方,恳请谅解,我们也欢迎任何改进建议。 -关于许可证,请花一分钟阅读我们简短的[许可和贡献者协议](./LICENSE)。同时也请遵循社区[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。 +关于许可证,请花一分钟阅读我们简短的[许可和贡献者协议](../../LICENSE)。同时也请遵循社区[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。 ## 开始之前 diff --git a/README_CN.md b/docs/zh-CN/README.md similarity index 79% rename from README_CN.md rename to docs/zh-CN/README.md index 9aaebf4037..202b99a6b1 100644 --- a/README_CN.md +++ b/docs/zh-CN/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)
Dify 云服务 · @@ -35,17 +35,19 @@

- README in English - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in বাংলা + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch + README in বাংলা
# @@ -111,7 +113,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI ### 快速启动 -启动 Dify 服务器的最简单方法是运行我们的 [docker-compose.yml](docker/docker-compose.yaml) 文件。在运行安装命令之前,请确保您的机器上安装了 [Docker](https://docs.docker.com/get-docker/) 和 [Docker Compose](https://docs.docker.com/compose/install/): +启动 Dify 服务器的最简单方法是运行我们的 [docker-compose.yml](../../docker/docker-compose.yaml) 文件。在运行安装命令之前,请确保您的机器上安装了 [Docker](https://docs.docker.com/get-docker/) 和 [Docker Compose](https://docs.docker.com/compose/install/): ```bash cd docker @@ -123,7 +125,7 @@ docker compose up -d ### 自定义配置 -如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。 +如果您需要自定义配置,请参考 [.env.example](../../docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。 #### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署 @@ -180,7 +182,7 @@ docker compose up -d ## Contributing -对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_CN.md)。 +对于那些想要贡献代码的人,请参阅我们的[贡献指南](./CONTRIBUTING.md)。 同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。 > 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 @@ -196,7 +198,7 @@ docker compose up -d 我们欢迎您为 Dify 做出贡献,以帮助改善 Dify。包括:提交代码、问题、新想法,或分享您基于 Dify 创建的有趣且有用的 AI 应用程序。同时,我们也欢迎您在不同的活动、会议和社交媒体上分享 Dify。 - [GitHub Discussion](https://github.com/langgenius/dify/discussions). 👉:分享您的应用程序并与社区交流。 -- [GitHub Issues](https://github.com/langgenius/dify/issues)。👉:使用 Dify.AI 时遇到的错误和问题,请参阅[贡献指南](CONTRIBUTING.md)。 +- [GitHub Issues](https://github.com/langgenius/dify/issues)。👉:使用 Dify.AI 时遇到的错误和问题,请参阅[贡献指南](./CONTRIBUTING.md)。 - [电子邮件支持](mailto:hello@dify.ai?subject=%5BGitHub%5DQuestions%20About%20Dify)。👉:关于使用 Dify.AI 的问题。 - [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。 - [X(Twitter)](https://twitter.com/dify_ai)。👉:分享您的应用程序并与社区交流。 @@ -208,4 +210,4 @@ docker compose up -d ## License -本仓库遵循 [Dify Open Source License](LICENSE) 开源协议,该许可证本质上是 Apache 2.0,但有一些额外的限制。 +本仓库遵循 [Dify Open Source License](../../LICENSE) 开源协议,该许可证本质上是 Apache 2.0,但有一些额外的限制。 diff --git a/CONTRIBUTING_TW.md b/docs/zh-TW/CONTRIBUTING.md similarity index 96% rename from CONTRIBUTING_TW.md rename to docs/zh-TW/CONTRIBUTING.md index 5c4d7022fe..1d5f02efa1 100644 --- a/CONTRIBUTING_TW.md +++ b/docs/zh-TW/CONTRIBUTING.md @@ -6,7 +6,7 @@ 這份指南與 Dify 一樣,都在持續完善中。如果指南內容有落後於實際專案的情況,還請見諒,也歡迎提供改進建議。 -關於授權部分,請花點時間閱讀我們簡短的[授權和貢獻者協議](./LICENSE)。社群也需遵守[行為準則](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。 +關於授權部分,請花點時間閱讀我們簡短的[授權和貢獻者協議](../../LICENSE)。社群也需遵守[行為準則](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。 ## 開始之前 diff --git a/README_TW.md b/docs/zh-TW/README.md similarity index 80% rename from README_TW.md rename to docs/zh-TW/README.md index 18d0724784..526e8d9c8c 100644 --- a/README_TW.md +++ b/docs/zh-TW/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) +![cover-v5-optimized](../../images/GitHub_README_if.png)

📌 介紹 Dify 工作流程檔案上傳功能:重現 Google NotebookLM Podcast @@ -39,18 +39,18 @@

- README in English - 繁體中文文件 - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in Deutsch + README in English + 繁體中文文件 + 简体中文文件 + 日本語のREADME + README en Español + README en Français + README tlhIngan Hol + README in Korean + README بالعربية + Türkçe README + README Tiếng Việt + README in Deutsch

Dify 是一個開源的 LLM 應用程式開發平台。其直觀的界面結合了智能代理工作流程、RAG 管道、代理功能、模型管理、可觀察性功能等,讓您能夠快速從原型進展到生產環境。 @@ -64,7 +64,7 @@ Dify 是一個開源的 LLM 應用程式開發平台。其直觀的界面結合
-啟動 Dify 伺服器最簡單的方式是透過 [docker compose](docker/docker-compose.yaml)。在使用以下命令運行 Dify 之前,請確保您的機器已安裝 [Docker](https://docs.docker.com/get-docker/) 和 [Docker Compose](https://docs.docker.com/compose/install/): +啟動 Dify 伺服器最簡單的方式是透過 [docker compose](../../docker/docker-compose.yaml)。在使用以下命令運行 Dify 之前,請確保您的機器已安裝 [Docker](https://docs.docker.com/get-docker/) 和 [Docker Compose](https://docs.docker.com/compose/install/): ```bash cd dify @@ -128,7 +128,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify ## 進階設定 -如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。 +如果您需要自定義配置,請參考我們的 [.env.example](../../docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。 如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。 @@ -173,7 +173,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify ## 貢獻 -對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TW.md)。 +對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](./CONTRIBUTING.md)。 同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。 > 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。 @@ -181,7 +181,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify ## 社群與聯絡方式 - [GitHub Discussion](https://github.com/langgenius/dify/discussions):最適合分享反饋和提問。 -- [GitHub Issues](https://github.com/langgenius/dify/issues):最適合報告使用 Dify.AI 時遇到的問題和提出功能建議。請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 +- [GitHub Issues](https://github.com/langgenius/dify/issues):最適合報告使用 Dify.AI 時遇到的問題和提出功能建議。請參閱我們的[貢獻指南](./CONTRIBUTING.md)。 - [Discord](https://discord.gg/FngNHpbcY7):最適合分享您的應用程式並與社群互動。 - [X(Twitter)](https://twitter.com/dify_ai):最適合分享您的應用程式並與社群互動。 @@ -201,4 +201,4 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify ## 授權條款 -本代碼庫採用 [Dify 開源授權](LICENSE),這基本上是 Apache 2.0 授權加上一些額外限制條款。 +本代碼庫採用 [Dify 開源授權](../../LICENSE),這基本上是 Apache 2.0 授權加上一些額外限制條款。 diff --git a/scripts/stress-test/setup/import_workflow_app.py b/scripts/stress-test/setup/import_workflow_app.py index 86d0239e35..41a76bd29b 100755 --- a/scripts/stress-test/setup/import_workflow_app.py +++ b/scripts/stress-test/setup/import_workflow_app.py @@ -8,7 +8,7 @@ sys.path.append(str(Path(__file__).parent.parent)) import json import httpx -from common import Logger, config_helper +from common import Logger, config_helper # type: ignore[import] def import_workflow_app() -> None: diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py index e866472f45..e252bc0472 100644 --- a/sdks/python-client/dify_client/__init__.py +++ b/sdks/python-client/dify_client/__init__.py @@ -4,6 +4,7 @@ from dify_client.client import ( DifyClient, KnowledgeBaseClient, WorkflowClient, + WorkspaceClient, ) __all__ = [ @@ -12,4 +13,5 @@ __all__ = [ "DifyClient", "KnowledgeBaseClient", "WorkflowClient", + "WorkspaceClient", ] diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index 791cb98a1b..fb42e3773d 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -1,5 +1,6 @@ import json -from typing import Literal +from typing import Literal, Union, Dict, List, Any, Optional, IO + import requests @@ -49,6 +50,18 @@ class DifyClient: params = {"user": user} return self._send_request("GET", "/meta", params=params) + def get_app_info(self): + """Get basic application information including name, description, tags, and mode.""" + return self._send_request("GET", "/info") + + def get_app_site_info(self): + """Get application site information.""" + return self._send_request("GET", "/site") + + def get_file_preview(self, file_id: str): + """Get file preview by file ID.""" + return self._send_request("GET", f"/files/{file_id}/preview") + class CompletionClient(DifyClient): def create_completion_message( @@ -139,11 +152,56 @@ class ChatClient(DifyClient): data = {"user": user} return self._send_request("DELETE", f"/conversations/{conversation_id}", data) - def audio_to_text(self, audio_file: dict, user: str): + def audio_to_text(self, audio_file: IO[bytes] | tuple, user: str): data = {"user": user} - files = {"audio_file": audio_file} + files = {"file": audio_file} return self._send_request_with_files("POST", "/audio-to-text", data, files) + # Annotation APIs + def annotation_reply_action( + self, + action: Literal["enable", "disable"], + score_threshold: float, + embedding_provider_name: str, + embedding_model_name: str, + ): + """Enable or disable annotation reply feature.""" + # Backend API requires these fields to be non-None values + if score_threshold is None or embedding_provider_name is None or embedding_model_name is None: + raise ValueError("score_threshold, embedding_provider_name, and embedding_model_name cannot be None") + + data = { + "score_threshold": score_threshold, + "embedding_provider_name": embedding_provider_name, + "embedding_model_name": embedding_model_name, + } + return self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) + + def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): + """Get the status of an annotation reply action job.""" + return self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") + + def list_annotations(self, page: int = 1, limit: int = 20, keyword: str = ""): + """List annotations for the application.""" + params = {"page": page, "limit": limit} + if keyword: + params["keyword"] = keyword + return self._send_request("GET", "/apps/annotations", params=params) + + def create_annotation(self, question: str, answer: str): + """Create a new annotation.""" + data = {"question": question, "answer": answer} + return self._send_request("POST", "/apps/annotations", json=data) + + def update_annotation(self, annotation_id: str, question: str, answer: str): + """Update an existing annotation.""" + data = {"question": question, "answer": answer} + return self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) + + def delete_annotation(self, annotation_id: str): + """Delete an annotation.""" + return self._send_request("DELETE", f"/apps/annotations/{annotation_id}") + class WorkflowClient(DifyClient): def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"): @@ -157,6 +215,55 @@ class WorkflowClient(DifyClient): def get_result(self, workflow_run_id): return self._send_request("GET", f"/workflows/run/{workflow_run_id}") + def get_workflow_logs( + self, + keyword: str = None, + status: Literal["succeeded", "failed", "stopped"] | None = None, + page: int = 1, + limit: int = 20, + created_at__before: str = None, + created_at__after: str = None, + created_by_end_user_session_id: str = None, + created_by_account: str = None, + ): + """Get workflow execution logs with optional filtering.""" + params = {"page": page, "limit": limit} + if keyword: + params["keyword"] = keyword + if status: + params["status"] = status + if created_at__before: + params["created_at__before"] = created_at__before + if created_at__after: + params["created_at__after"] = created_at__after + if created_by_end_user_session_id: + params["created_by_end_user_session_id"] = created_by_end_user_session_id + if created_by_account: + params["created_by_account"] = created_by_account + return self._send_request("GET", "/workflows/logs", params=params) + + def run_specific_workflow( + self, + workflow_id: str, + inputs: dict, + response_mode: Literal["blocking", "streaming"] = "streaming", + user: str = "abc-123", + ): + """Run a specific workflow by workflow ID.""" + data = {"inputs": inputs, "response_mode": response_mode, "user": user} + return self._send_request( + "POST", f"/workflows/{workflow_id}/run", data, stream=True if response_mode == "streaming" else False + ) + + +class WorkspaceClient(DifyClient): + """Client for workspace-related operations.""" + + def get_available_models(self, model_type: str): + """Get available models by model type.""" + url = f"/workspaces/current/models/model-types/{model_type}" + return self._send_request("GET", url) + class KnowledgeBaseClient(DifyClient): def __init__( @@ -443,3 +550,117 @@ class KnowledgeBaseClient(DifyClient): data = {"segment": segment_data} url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" return self._send_request("POST", url, json=data, **kwargs) + + # Advanced Knowledge Base APIs + def hit_testing( + self, query: str, retrieval_model: Dict[str, Any] = None, external_retrieval_model: Dict[str, Any] = None + ): + """Perform hit testing on the dataset.""" + data = {"query": query} + if retrieval_model: + data["retrieval_model"] = retrieval_model + if external_retrieval_model: + data["external_retrieval_model"] = external_retrieval_model + url = f"/datasets/{self._get_dataset_id()}/hit-testing" + return self._send_request("POST", url, json=data) + + def get_dataset_metadata(self): + """Get dataset metadata.""" + url = f"/datasets/{self._get_dataset_id()}/metadata" + return self._send_request("GET", url) + + def create_dataset_metadata(self, metadata_data: Dict[str, Any]): + """Create dataset metadata.""" + url = f"/datasets/{self._get_dataset_id()}/metadata" + return self._send_request("POST", url, json=metadata_data) + + def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]): + """Update dataset metadata.""" + url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}" + return self._send_request("PATCH", url, json=metadata_data) + + def get_built_in_metadata(self): + """Get built-in metadata.""" + url = f"/datasets/{self._get_dataset_id()}/metadata/built-in" + return self._send_request("GET", url) + + def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] = None): + """Manage built-in metadata with specified action.""" + data = metadata_data or {} + url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}" + return self._send_request("POST", url, json=data) + + def update_documents_metadata(self, operation_data: List[Dict[str, Any]]): + """Update metadata for multiple documents.""" + url = f"/datasets/{self._get_dataset_id()}/documents/metadata" + data = {"operation_data": operation_data} + return self._send_request("POST", url, json=data) + + # Dataset Tags APIs + def list_dataset_tags(self): + """List all dataset tags.""" + return self._send_request("GET", "/datasets/tags") + + def bind_dataset_tags(self, tag_ids: List[str]): + """Bind tags to dataset.""" + data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()} + return self._send_request("POST", "/datasets/tags/binding", json=data) + + def unbind_dataset_tag(self, tag_id: str): + """Unbind a single tag from dataset.""" + data = {"tag_id": tag_id, "target_id": self._get_dataset_id()} + return self._send_request("POST", "/datasets/tags/unbinding", json=data) + + def get_dataset_tags(self): + """Get tags for current dataset.""" + url = f"/datasets/{self._get_dataset_id()}/tags" + return self._send_request("GET", url) + + # RAG Pipeline APIs + def get_datasource_plugins(self, is_published: bool = True): + """Get datasource plugins for RAG pipeline.""" + params = {"is_published": is_published} + url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins" + return self._send_request("GET", url, params=params) + + def run_datasource_node( + self, + node_id: str, + inputs: Dict[str, Any], + datasource_type: str, + is_published: bool = True, + credential_id: str = None, + ): + """Run a datasource node in RAG pipeline.""" + data = {"inputs": inputs, "datasource_type": datasource_type, "is_published": is_published} + if credential_id: + data["credential_id"] = credential_id + url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" + return self._send_request("POST", url, json=data, stream=True) + + def run_rag_pipeline( + self, + inputs: Dict[str, Any], + datasource_type: str, + datasource_info_list: List[Dict[str, Any]], + start_node_id: str, + is_published: bool = True, + response_mode: Literal["streaming", "blocking"] = "blocking", + ): + """Run RAG pipeline.""" + data = { + "inputs": inputs, + "datasource_type": datasource_type, + "datasource_info_list": datasource_info_list, + "start_node_id": start_node_id, + "is_published": is_published, + "response_mode": response_mode, + } + url = f"/datasets/{self._get_dataset_id()}/pipeline/run" + return self._send_request("POST", url, json=data, stream=response_mode == "streaming") + + def upload_pipeline_file(self, file_path: str): + """Upload file for RAG pipeline.""" + with open(file_path, "rb") as f: + files = {"file": f} + return self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) diff --git a/sdks/python-client/tests/test_new_apis.py b/sdks/python-client/tests/test_new_apis.py new file mode 100644 index 0000000000..09c62dfda7 --- /dev/null +++ b/sdks/python-client/tests/test_new_apis.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +""" +Test suite for the new Service API functionality in the Python SDK. + +This test validates the implementation of the missing Service API endpoints +that were added to the Python SDK to achieve complete coverage. +""" + +import unittest +from unittest.mock import Mock, patch, MagicMock +import json + +from dify_client import ( + DifyClient, + ChatClient, + WorkflowClient, + KnowledgeBaseClient, + WorkspaceClient, +) + + +class TestNewServiceAPIs(unittest.TestCase): + """Test cases for new Service API implementations.""" + + def setUp(self): + """Set up test fixtures.""" + self.api_key = "test-api-key" + self.base_url = "https://api.dify.ai/v1" + + @patch("dify_client.client.requests.request") + def test_app_info_apis(self, mock_request): + """Test application info APIs.""" + mock_response = Mock() + mock_response.json.return_value = { + "name": "Test App", + "description": "Test Description", + "tags": ["test", "api"], + "mode": "chat", + "author_name": "Test Author", + } + mock_request.return_value = mock_response + + client = DifyClient(self.api_key, self.base_url) + + # Test get_app_info + result = client.get_app_info() + mock_request.assert_called_with( + "GET", + f"{self.base_url}/info", + json=None, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test get_app_site_info + client.get_app_site_info() + mock_request.assert_called_with( + "GET", + f"{self.base_url}/site", + json=None, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test get_file_preview + file_id = "test-file-id" + client.get_file_preview(file_id) + mock_request.assert_called_with( + "GET", + f"{self.base_url}/files/{file_id}/preview", + json=None, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + @patch("dify_client.client.requests.request") + def test_annotation_apis(self, mock_request): + """Test annotation APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + client = ChatClient(self.api_key, self.base_url) + + # Test annotation_reply_action - enable + client.annotation_reply_action( + action="enable", + score_threshold=0.8, + embedding_provider_name="openai", + embedding_model_name="text-embedding-ada-002", + ) + mock_request.assert_called_with( + "POST", + f"{self.base_url}/apps/annotation-reply/enable", + json={ + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + }, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test annotation_reply_action - disable (now requires same fields as enable) + client.annotation_reply_action( + action="disable", + score_threshold=0.5, + embedding_provider_name="openai", + embedding_model_name="text-embedding-ada-002", + ) + + # Test annotation_reply_action with score_threshold=0 (edge case) + client.annotation_reply_action( + action="enable", + score_threshold=0.0, # This should work and not raise ValueError + embedding_provider_name="openai", + embedding_model_name="text-embedding-ada-002", + ) + + # Test get_annotation_reply_status + client.get_annotation_reply_status("enable", "job-123") + + # Test list_annotations + client.list_annotations(page=1, limit=20, keyword="test") + + # Test create_annotation + client.create_annotation("Test question?", "Test answer.") + + # Test update_annotation + client.update_annotation("annotation-123", "Updated question?", "Updated answer.") + + # Test delete_annotation + client.delete_annotation("annotation-123") + + # Verify all calls were made (8 calls: enable + disable + enable with 0.0 + 5 other operations) + self.assertEqual(mock_request.call_count, 8) + + @patch("dify_client.client.requests.request") + def test_knowledge_base_advanced_apis(self, mock_request): + """Test advanced knowledge base APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + dataset_id = "test-dataset-id" + client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id) + + # Test hit_testing + client.hit_testing("test query", {"type": "vector"}) + mock_request.assert_called_with( + "POST", + f"{self.base_url}/datasets/{dataset_id}/hit-testing", + json={"query": "test query", "retrieval_model": {"type": "vector"}}, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test metadata operations + client.get_dataset_metadata() + client.create_dataset_metadata({"key": "value"}) + client.update_dataset_metadata("meta-123", {"key": "new_value"}) + client.get_built_in_metadata() + client.manage_built_in_metadata("enable", {"type": "built_in"}) + client.update_documents_metadata([{"document_id": "doc1", "metadata": {"key": "value"}}]) + + # Test tag operations + client.list_dataset_tags() + client.bind_dataset_tags(["tag1", "tag2"]) + client.unbind_dataset_tag("tag1") + client.get_dataset_tags() + + # Verify multiple calls were made + self.assertGreater(mock_request.call_count, 5) + + @patch("dify_client.client.requests.request") + def test_rag_pipeline_apis(self, mock_request): + """Test RAG pipeline APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + dataset_id = "test-dataset-id" + client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id) + + # Test get_datasource_plugins + client.get_datasource_plugins(is_published=True) + mock_request.assert_called_with( + "GET", + f"{self.base_url}/datasets/{dataset_id}/pipeline/datasource-plugins", + json=None, + params={"is_published": True}, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test run_datasource_node + client.run_datasource_node( + node_id="node-123", + inputs={"param": "value"}, + datasource_type="online_document", + is_published=True, + credential_id="cred-123", + ) + + # Test run_rag_pipeline with blocking mode + client.run_rag_pipeline( + inputs={"query": "test"}, + datasource_type="online_document", + datasource_info_list=[{"id": "ds1"}], + start_node_id="start-node", + is_published=True, + response_mode="blocking", + ) + + # Test run_rag_pipeline with streaming mode + client.run_rag_pipeline( + inputs={"query": "test"}, + datasource_type="online_document", + datasource_info_list=[{"id": "ds1"}], + start_node_id="start-node", + is_published=True, + response_mode="streaming", + ) + + self.assertEqual(mock_request.call_count, 4) + + @patch("dify_client.client.requests.request") + def test_workspace_apis(self, mock_request): + """Test workspace APIs.""" + mock_response = Mock() + mock_response.json.return_value = { + "data": [{"name": "gpt-3.5-turbo", "type": "llm"}, {"name": "gpt-4", "type": "llm"}] + } + mock_request.return_value = mock_response + + client = WorkspaceClient(self.api_key, self.base_url) + + # Test get_available_models + result = client.get_available_models("llm") + mock_request.assert_called_with( + "GET", + f"{self.base_url}/workspaces/current/models/model-types/llm", + json=None, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + @patch("dify_client.client.requests.request") + def test_workflow_advanced_apis(self, mock_request): + """Test advanced workflow APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + client = WorkflowClient(self.api_key, self.base_url) + + # Test get_workflow_logs + client.get_workflow_logs(keyword="test", status="succeeded", page=1, limit=20) + mock_request.assert_called_with( + "GET", + f"{self.base_url}/workflows/logs", + json=None, + params={"page": 1, "limit": 20, "keyword": "test", "status": "succeeded"}, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test get_workflow_logs with additional filters + client.get_workflow_logs( + keyword="test", + status="succeeded", + page=1, + limit=20, + created_at__before="2024-01-01", + created_at__after="2023-01-01", + created_by_account="user123", + ) + + # Test run_specific_workflow + client.run_specific_workflow( + workflow_id="workflow-123", inputs={"param": "value"}, response_mode="streaming", user="user-123" + ) + + self.assertEqual(mock_request.call_count, 3) + + def test_error_handling(self): + """Test error handling for required parameters.""" + client = ChatClient(self.api_key, self.base_url) + + # Test annotation_reply_action with missing required parameters would be a TypeError now + # since parameters are required in method signature + with self.assertRaises(TypeError): + client.annotation_reply_action("enable") + + # Test annotation_reply_action with explicit None values should raise ValueError + with self.assertRaises(ValueError) as context: + client.annotation_reply_action("enable", None, "provider", "model") + + self.assertIn("cannot be None", str(context.exception)) + + # Test KnowledgeBaseClient without dataset_id + kb_client = KnowledgeBaseClient(self.api_key, self.base_url) + with self.assertRaises(ValueError) as context: + kb_client.hit_testing("test query") + + self.assertIn("dataset_id is not set", str(context.exception)) + + @patch("dify_client.client.open") + @patch("dify_client.client.requests.request") + def test_file_upload_apis(self, mock_request, mock_open): + """Test file upload APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + mock_file = MagicMock() + mock_open.return_value.__enter__.return_value = mock_file + + dataset_id = "test-dataset-id" + client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id) + + # Test upload_pipeline_file + client.upload_pipeline_file("/path/to/test.pdf") + + mock_open.assert_called_with("/path/to/test.pdf", "rb") + mock_request.assert_called_once() + + def test_comprehensive_coverage(self): + """Test that all previously missing APIs are now implemented.""" + + # Test DifyClient methods + dify_methods = ["get_app_info", "get_app_site_info", "get_file_preview"] + client = DifyClient(self.api_key) + for method in dify_methods: + self.assertTrue(hasattr(client, method), f"DifyClient missing method: {method}") + + # Test ChatClient annotation methods + chat_methods = [ + "annotation_reply_action", + "get_annotation_reply_status", + "list_annotations", + "create_annotation", + "update_annotation", + "delete_annotation", + ] + chat_client = ChatClient(self.api_key) + for method in chat_methods: + self.assertTrue(hasattr(chat_client, method), f"ChatClient missing method: {method}") + + # Test WorkflowClient advanced methods + workflow_methods = ["get_workflow_logs", "run_specific_workflow"] + workflow_client = WorkflowClient(self.api_key) + for method in workflow_methods: + self.assertTrue(hasattr(workflow_client, method), f"WorkflowClient missing method: {method}") + + # Test KnowledgeBaseClient advanced methods + kb_methods = [ + "hit_testing", + "get_dataset_metadata", + "create_dataset_metadata", + "update_dataset_metadata", + "get_built_in_metadata", + "manage_built_in_metadata", + "update_documents_metadata", + "list_dataset_tags", + "bind_dataset_tags", + "unbind_dataset_tag", + "get_dataset_tags", + "get_datasource_plugins", + "run_datasource_node", + "run_rag_pipeline", + "upload_pipeline_file", + ] + kb_client = KnowledgeBaseClient(self.api_key) + for method in kb_methods: + self.assertTrue(hasattr(kb_client, method), f"KnowledgeBaseClient missing method: {method}") + + # Test WorkspaceClient methods + workspace_methods = ["get_available_models"] + workspace_client = WorkspaceClient(self.api_key) + for method in workspace_methods: + self.assertTrue(hasattr(workspace_client, method), f"WorkspaceClient missing method: {method}") + + +if __name__ == "__main__": + unittest.main() diff --git a/spec.http b/spec.http deleted file mode 100644 index dc3a37d08a..0000000000 --- a/spec.http +++ /dev/null @@ -1,4 +0,0 @@ -GET /console/api/spec/schema-definitions -Host: cloud-rag.dify.dev -authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiNzExMDZhYTQtZWJlMC00NGMzLWI4NWYtMWQ4Mjc5ZTExOGZmIiwiZXhwIjoxNzU2MTkyNDE4LCJpc3MiOiJDTE9VRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.Yx_TMdWVXCp5YEoQ8WR90lRhHHKggxAQvEl5RUnkZuc -### \ No newline at end of file diff --git a/web/__tests__/document-detail-navigation-fix.test.tsx b/web/__tests__/document-detail-navigation-fix.test.tsx index 200ed09ea9..a358744998 100644 --- a/web/__tests__/document-detail-navigation-fix.test.tsx +++ b/web/__tests__/document-detail-navigation-fix.test.tsx @@ -54,7 +54,7 @@ const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; d return (
-
diff --git a/web/__tests__/goto-anything/command-selector.test.tsx b/web/__tests__/goto-anything/command-selector.test.tsx index 1db4be31fb..6d4e045d49 100644 --- a/web/__tests__/goto-anything/command-selector.test.tsx +++ b/web/__tests__/goto-anything/command-selector.test.tsx @@ -16,7 +16,7 @@ jest.mock('cmdk', () => ({ Item: ({ children, onSelect, value, className }: any) => (
onSelect && onSelect()} + onClick={() => onSelect?.()} data-value={value} data-testid={`command-item-${value}`} > diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx index 1ab40e31bf..246a1eb6a3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx @@ -4,6 +4,7 @@ import React, { useCallback, useRef, useState } from 'react' import type { PopupProps } from './config-popup' import ConfigPopup from './config-popup' +import cn from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, @@ -45,7 +46,7 @@ const ConfigBtn: FC = ({ offset={12} > -
+
{children}
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index d22577c9ad..baf52946df 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -260,7 +260,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx return (
{!onlyShowDetail && ( - - -
- - - {(app.mode === 'completion' || app.mode === 'chat') && ( <> : !(isGettingUserCanAccessApp || !userCanAccessApp?.result) && ( <> - @@ -300,13 +301,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { { systemFeatures.webapp_auth.enabled && isCurrentWorkspaceEditor && <> - } - + ((e.target as HTMLInputElement).value = '')} diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx index a8de07bf6b..3deb6a6c8f 100644 --- a/web/app/components/base/app-icon-picker/index.tsx +++ b/web/app/components/base/app-icon-picker/index.tsx @@ -117,7 +117,7 @@ const AppIconPicker: FC = ({ {!DISABLE_UPLOAD_IMAGE_AS_ICON &&
{tabs.map(tab => ( -
- -
)}
diff --git a/web/app/components/base/emoji-picker/index.tsx b/web/app/components/base/emoji-picker/index.tsx index d3b20bb507..7b91c62797 100644 --- a/web/app/components/base/emoji-picker/index.tsx +++ b/web/app/components/base/emoji-picker/index.tsx @@ -45,7 +45,7 @@ const EmojiPicker: FC = ({
@@ -54,7 +54,7 @@ const EmojiPicker: FC = ({ variant="primary" className='w-full' onClick={() => { - onSelect && onSelect(selectedEmoji, selectedBackground!) + onSelect?.(selectedEmoji, selectedBackground!) }}> {t('app.iconPicker.ok')} diff --git a/web/app/components/base/form/components/field/select.tsx b/web/app/components/base/form/components/field/select.tsx index f12b90335b..dee047e2eb 100644 --- a/web/app/components/base/form/components/field/select.tsx +++ b/web/app/components/base/form/components/field/select.tsx @@ -33,7 +33,10 @@ const SelectField = ({ field.handleChange(value)} + onChange={(value) => { + field.handleChange(value) + onChange?.(value) + }} {...selectProps} />
diff --git a/web/app/components/base/icons/assets/public/tracing/aliyun-icon-big.svg b/web/app/components/base/icons/assets/public/tracing/aliyun-icon-big.svg index 210a1cd00b..d82b9bc1e4 100644 --- a/web/app/components/base/icons/assets/public/tracing/aliyun-icon-big.svg +++ b/web/app/components/base/icons/assets/public/tracing/aliyun-icon-big.svg @@ -1 +1 @@ - + \ No newline at end of file diff --git a/web/app/components/base/icons/assets/public/tracing/aliyun-icon.svg b/web/app/components/base/icons/assets/public/tracing/aliyun-icon.svg index 6f7645301c..cee8858471 100644 --- a/web/app/components/base/icons/assets/public/tracing/aliyun-icon.svg +++ b/web/app/components/base/icons/assets/public/tracing/aliyun-icon.svg @@ -1 +1 @@ - + \ No newline at end of file diff --git a/web/app/components/base/icons/src/public/tracing/AliyunIcon.json b/web/app/components/base/icons/src/public/tracing/AliyunIcon.json index 5cbb52c237..154aeff8c6 100644 --- a/web/app/components/base/icons/src/public/tracing/AliyunIcon.json +++ b/web/app/components/base/icons/src/public/tracing/AliyunIcon.json @@ -1,118 +1,129 @@ { - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "xmlns": "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink", - "fill": "none", - "version": "1.1", - "width": "65", - "height": "16", - "viewBox": "0 0 65 16" - }, - "children": [ - { - "type": "element", - "name": "defs", - "children": [ - { - "type": "element", - "name": "clipPath", - "attributes": { - "id": "master_svg0_42_34281" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "x": "0", - "y": "0", - "width": "19", - "height": "16", - "rx": "0" - } - } - ] - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "g", - "attributes": { - "clip-path": "url(#master_svg0_42_34281)" - }, - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M4.06862,14.6667C3.79213,14.6667,3.45463,14.5688,3.05614,14.373C2.97908,14.3351,2.92692,14.3105,2.89968,14.2992C2.33193,14.0628,1.82911,13.7294,1.39123,13.2989C0.463742,12.3871,0,11.2874,0,10C0,8.71258,0.463742,7.61293,1.39123,6.70107C2.16172,5.94358,3.06404,5.50073,4.09819,5.37252C4.23172,3.98276,4.81755,2.77756,5.85569,1.75693C7.04708,0.585642,8.4857,0,10.1716,0C11.5256,0,12.743,0.396982,13.8239,1.19095C14.8847,1.97019,15.61,2.97855,16,4.21604L14.7045,4.61063C14.4016,3.64918,13.8374,2.86532,13.0121,2.25905C12.1719,1.64191,11.2251,1.33333,10.1716,1.33333C8.8602,1.33333,7.74124,1.7888,6.81467,2.69974C5.88811,3.61067,5.42483,4.71076,5.42483,6L5.42483,6.66667L4.74673,6.66667C3.81172,6.66667,3.01288,6.99242,2.35021,7.64393C1.68754,8.2954,1.35621,9.08076,1.35621,10C1.35621,10.9192,1.68754,11.7046,2.35021,12.3561C2.66354,12.6641,3.02298,12.9026,3.42852,13.0714C3.48193,13.0937,3.55988,13.13,3.66237,13.1803C3.87004,13.2823,4.00545,13.3333,4.06862,13.3333L4.06862,14.6667Z", - "fill-rule": "evenodd", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M13.458613505859375,7.779393492279053C12.975613505859375,7.717463492279053,12.484813505859375,7.686503492279053,11.993983505859376,7.686503492279053C11.152583505859376,7.686503492279053,10.303403505859375,7.779393492279053,9.493183505859374,7.941943492279053C8.682953505859375,8.104503492279052,7.903893505859375,8.359943492279053,7.155983505859375,8.654083492279053C6.657383505859375,8.870823492279053,6.158783505859375,9.128843492279053,5.660181505859375,9.428153492279053C5.332974751859375,9.621673492279053,5.239486705859375,10.070633492279054,5.434253505859375,10.395743492279053L7.413073505859375,13.298533492279052C7.639003505859375,13.623603492279052,8.090863505859375,13.716463492279052,8.418073505859375,13.523003492279052C8.547913505859375,13.435263492279052,8.763453505859374,13.326893492279053,9.064693505859374,13.197863492279053C9.516553505859374,13.004333492279052,9.976203505859374,12.872733492279053,10.459223505859375,12.779863492279052C10.942243505859375,12.679263492279052,11.433053505859375,12.617333492279052,11.955023505859375,12.617333492279052L13.380683505859375,7.810353492279052L13.458613505859375,7.779393492279053ZM15.273813505859374,8.135463492279053L15.016753505859375,5.333333492279053L13.458613505859375,7.787133492279053C13.817013505859375,7.818093492279052,14.144213505859375,7.880023492279053,14.494743505859375,7.949683492279053C14.494743505859375,7.944523492279053,14.754433505859375,8.006453492279054,15.273813505859374,8.135463492279053ZM12.064083505859376,12.648273492279053L11.378523505859375,14.970463492279054L12.515943505859376,16.00003349227905L14.074083505859376,15.643933492279054L14.525943505859376,13.027603492279052C14.198743505859374,12.934663492279054,13.879283505859375,12.834063492279054,13.552083505859375,12.772133492279053C13.069083505859375,12.717933492279052,12.578283505859375,12.648273492279053,12.064083505859376,12.648273492279053ZM18.327743505859374,9.428153492279053C17.829143505859374,9.128843492279053,17.330543505859374,8.870823492279053,16.831943505859375,8.654083492279053C16.348943505859374,8.460573492279053,15.826943505859376,8.267053492279054,15.305013505859375,8.135463492279053L15.305013505859375,8.267053492279054L14.463613505859374,13.043063492279053C14.596083505859376,13.105003492279053,14.759683505859375,13.135933492279053,14.884283505859376,13.205603492279053C15.185523505859376,13.334623492279052,15.401043505859375,13.443003492279052,15.530943505859375,13.530733492279053C15.858143505859376,13.724263492279054,16.341143505859375,13.623603492279052,16.535943505859375,13.306263492279053L18.514743505859375,10.403483492279053C18.779643505859376,10.039673492279054,18.686143505859377,9.621673492279053,18.327743505859374,9.428153492279053Z", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - } - ] - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M25.044,2.668L34.676,2.668L34.676,4.04L25.044,4.04L25.044,2.668ZM29.958,7.82Q29.258,9.066,28.355,10.41Q27.451999999999998,11.754,26.92,12.3L32.506,11.782Q31.442,10.158,30.84,9.346L32.058,8.562000000000001Q32.786,9.5,33.843,11.012Q34.9,12.524,35.516,13.546L34.214,14.526Q33.891999999999996,13.966,33.346000000000004,13.098Q32.016,13.182,29.734,13.378Q27.451999999999998,13.574,25.87,13.742L25.31,13.812L24.834,13.882L24.414,12.468Q24.708,12.37,24.862000000000002,12.265Q25.016,12.16,25.121,12.069Q25.226,11.978,25.268,11.936Q25.912,11.32,26.724,10.165Q27.536,9.01,28.208,7.82L23.854,7.82L23.854,6.434L35.866,6.434L35.866,7.82L29.958,7.82ZM42.656,7.414L42.656,8.576L41.354,8.576L41.354,1.814L42.656,1.87L42.656,7.036Q43.314,5.846,43.888000000000005,4.369Q44.462,2.892,44.714,1.6600000000000001L46.086,1.981999Q45.96,2.612,45.722,3.41L49.6,3.41L49.6,4.74L45.274,4.74Q44.616,6.56,43.706,8.128L42.656,7.414ZM38.596000000000004,2.346L39.884,2.402L39.884,8.212L38.596000000000004,8.212L38.596000000000004,2.346ZM46.184,4.964Q46.688,5.356,47.5,6.175Q48.312,6.994,48.788,7.582L47.751999999999995,8.59Q47.346000000000004,8.072,46.576,7.274Q45.806,6.476,45.204,5.902L46.184,4.964ZM48.41,9.01L48.41,12.706L49.894,12.706L49.894,13.966L37.391999999999996,13.966L37.391999999999996,12.706L38.848,12.706L38.848,9.01L48.41,9.01ZM41.676,10.256L40.164,10.256L40.164,12.706L41.676,12.706L41.676,10.256ZM42.908,12.706L44.364000000000004,12.706L44.364000000000004,10.256L42.908,10.256L42.908,12.706ZM45.582,12.706L47.108000000000004,12.706L47.108000000000004,10.256L45.582,10.256L45.582,12.706ZM54.906,7.456L55.116,8.394L54.178,8.814L54.178,12.818Q54.178,13.434,54.031,13.735Q53.884,14.036,53.534,14.162Q53.184,14.288,52.456,14.358L51.867999999999995,14.414L51.476,13.084L52.162,13.028Q52.512,13,52.652,12.958Q52.792,12.916,52.841,12.797Q52.89,12.678,52.89,12.384L52.89,9.36Q51.980000000000004,9.724,51.322,9.948L51.013999999999996,8.576Q51.798,8.324,52.89,7.876L52.89,5.524L51.42,5.524L51.42,4.166L52.89,4.166L52.89,1.7579989999999999L54.178,1.814L54.178,4.166L55.214,4.166L55.214,5.524L54.178,5.524L54.178,7.316L54.808,7.022L54.906,7.456ZM56.894,4.5440000000000005L56.894,6.098L55.564,6.098L55.564,3.256L58.686,3.256Q58.42,2.346,58.266,1.9260000000000002L59.624,1.7579989999999999Q59.848,2.276,60.142,3.256L63.25,3.256L63.25,6.098L61.962,6.098L61.962,4.5440000000000005L56.894,4.5440000000000005ZM59.008,6.322Q58.392,6.938,57.685,7.512Q56.978,8.086,55.956,8.841999999999999L55.242,7.764Q56.824,6.728,58.126,5.37L59.008,6.322ZM60.422,5.37Q61.024,5.776,62.095,6.581Q63.166,7.386,63.656,7.806L62.942,8.982Q62.368,8.45,61.332,7.652Q60.296,6.854,59.666,6.434L60.422,5.37ZM62.592,10.256L60.044,10.256L60.044,12.566L63.572,12.566L63.572,13.826L55.144,13.826L55.144,12.566L58.63,12.566L58.63,10.256L56.054,10.256L56.054,8.982L62.592,8.982L62.592,10.256Z", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - } - ] - } - ] - } - ] - }, - "name": "AliyunIcon" + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + "fill": "none", + "version": "1.1", + "width": "65", + "height": "16", + "viewBox": "0 0 65 16" + }, + "children": [ + { + "type": "element", + "name": "defs", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "clipPath", + "attributes": { + "id": "master_svg0_42_34281" + }, + "children": [ + { + "type": "element", + "name": "rect", + "attributes": { + "x": "0", + "y": "0", + "width": "19", + "height": "16", + "rx": "0" + }, + "children": [] + } + ] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "clip-path": "url(#master_svg0_42_34281)" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M4.06862,14.6667C3.79213,14.6667,3.45463,14.5688,3.05614,14.373C2.97908,14.3351,2.92692,14.3105,2.89968,14.2992C2.33193,14.0628,1.82911,13.7294,1.39123,13.2989C0.463742,12.3871,0,11.2874,0,10C0,8.71258,0.463742,7.61293,1.39123,6.70107C2.16172,5.94358,3.06404,5.50073,4.09819,5.37252C4.23172,3.98276,4.81755,2.77756,5.85569,1.75693C7.04708,0.585642,8.4857,0,10.1716,0C11.5256,0,12.743,0.396982,13.8239,1.19095C14.8847,1.97019,15.61,2.97855,16,4.21604L14.7045,4.61063C14.4016,3.64918,13.8374,2.86532,13.0121,2.25905C12.1719,1.64191,11.2251,1.33333,10.1716,1.33333C8.8602,1.33333,7.74124,1.7888,6.81467,2.69974C5.88811,3.61067,5.42483,4.71076,5.42483,6L5.42483,6.66667L4.74673,6.66667C3.81172,6.66667,3.01288,6.99242,2.35021,7.64393C1.68754,8.2954,1.35621,9.08076,1.35621,10C1.35621,10.9192,1.68754,11.7046,2.35021,12.3561C2.66354,12.6641,3.02298,12.9026,3.42852,13.0714C3.48193,13.0937,3.55988,13.13,3.66237,13.1803C3.87004,13.2823,4.00545,13.3333,4.06862,13.3333L4.06862,14.6667Z", + "fill-rule": "evenodd", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M13.458613505859375,7.779393492279053C12.975613505859375,7.717463492279053,12.484813505859375,7.686503492279053,11.993983505859376,7.686503492279053C11.152583505859376,7.686503492279053,10.303403505859375,7.779393492279053,9.493183505859374,7.941943492279053C8.682953505859375,8.104503492279052,7.903893505859375,8.359943492279053,7.155983505859375,8.654083492279053C6.657383505859375,8.870823492279053,6.158783505859375,9.128843492279053,5.660181505859375,9.428153492279053C5.332974751859375,9.621673492279053,5.239486705859375,10.070633492279054,5.434253505859375,10.395743492279053L7.413073505859375,13.298533492279052C7.639003505859375,13.623603492279052,8.090863505859375,13.716463492279052,8.418073505859375,13.523003492279052C8.547913505859375,13.435263492279052,8.763453505859374,13.326893492279053,9.064693505859374,13.197863492279053C9.516553505859374,13.004333492279052,9.976203505859374,12.872733492279053,10.459223505859375,12.779863492279052C10.942243505859375,12.679263492279052,11.433053505859375,12.617333492279052,11.955023505859375,12.617333492279052L13.380683505859375,7.810353492279052L13.458613505859375,7.779393492279053ZM15.273813505859374,8.135463492279053L15.016753505859375,5.333333492279053L13.458613505859375,7.787133492279053C13.817013505859375,7.818093492279052,14.144213505859375,7.880023492279053,14.494743505859375,7.949683492279053C14.494743505859375,7.944523492279053,14.754433505859375,8.006453492279054,15.273813505859374,8.135463492279053ZM12.064083505859376,12.648273492279053L11.378523505859375,14.970463492279054L12.515943505859376,16.00003349227905L14.074083505859376,15.643933492279054L14.525943505859376,13.027603492279052C14.198743505859374,12.934663492279054,13.879283505859375,12.834063492279054,13.552083505859375,12.772133492279053C13.069083505859375,12.717933492279052,12.578283505859375,12.648273492279053,12.064083505859376,12.648273492279053ZM18.327743505859374,9.428153492279053C17.829143505859374,9.128843492279053,17.330543505859374,8.870823492279053,16.831943505859375,8.654083492279053C16.348943505859374,8.460573492279053,15.826943505859376,8.267053492279054,15.305013505859375,8.135463492279053L15.305013505859375,8.267053492279054L14.463613505859374,13.043063492279053C14.596083505859376,13.105003492279053,14.759683505859375,13.135933492279053,14.884283505859376,13.205603492279053C15.185523505859376,13.334623492279052,15.401043505859375,13.443003492279052,15.530943505859375,13.530733492279053C15.858143505859376,13.724263492279054,16.341143505859375,13.623603492279052,16.535943505859375,13.306263492279053L18.514743505859375,10.403483492279053C18.779643505859376,10.039673492279054,18.686143505859377,9.621673492279053,18.327743505859374,9.428153492279053Z", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + } + ] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M25.044,2.668L34.676,2.668L34.676,4.04L25.044,4.04L25.044,2.668ZM29.958,7.82Q29.258,9.066,28.355,10.41Q27.451999999999998,11.754,26.92,12.3L32.506,11.782Q31.442,10.158,30.84,9.346L32.058,8.562000000000001Q32.786,9.5,33.843,11.012Q34.9,12.524,35.516,13.546L34.214,14.526Q33.891999999999996,13.966,33.346000000000004,13.098Q32.016,13.182,29.734,13.378Q27.451999999999998,13.574,25.87,13.742L25.31,13.812L24.834,13.882L24.414,12.468Q24.708,12.37,24.862000000000002,12.265Q25.016,12.16,25.121,12.069Q25.226,11.978,25.268,11.936Q25.912,11.32,26.724,10.165Q27.536,9.01,28.208,7.82L23.854,7.82L23.854,6.434L35.866,6.434L35.866,7.82L29.958,7.82ZM42.656,7.414L42.656,8.576L41.354,8.576L41.354,1.814L42.656,1.87L42.656,7.036Q43.314,5.846,43.888000000000005,4.369Q44.462,2.892,44.714,1.6600000000000001L46.086,1.981999Q45.96,2.612,45.722,3.41L49.6,3.41L49.6,4.74L45.274,4.74Q44.616,6.56,43.706,8.128L42.656,7.414ZM38.596000000000004,2.346L39.884,2.402L39.884,8.212L38.596000000000004,8.212L38.596000000000004,2.346ZM46.184,4.964Q46.688,5.356,47.5,6.175Q48.312,6.994,48.788,7.582L47.751999999999995,8.59Q47.346000000000004,8.072,46.576,7.274Q45.806,6.476,45.204,5.902L46.184,4.964ZM48.41,9.01L48.41,12.706L49.894,12.706L49.894,13.966L37.391999999999996,13.966L37.391999999999996,12.706L38.848,12.706L38.848,9.01L48.41,9.01ZM41.676,10.256L40.164,10.256L40.164,12.706L41.676,12.706L41.676,10.256ZM42.908,12.706L44.364000000000004,12.706L44.364000000000004,10.256L42.908,10.256L42.908,12.706ZM45.582,12.706L47.108000000000004,12.706L47.108000000000004,10.256L45.582,10.256L45.582,12.706ZM54.906,7.456L55.116,8.394L54.178,8.814L54.178,12.818Q54.178,13.434,54.031,13.735Q53.884,14.036,53.534,14.162Q53.184,14.288,52.456,14.358L51.867999999999995,14.414L51.476,13.084L52.162,13.028Q52.512,13,52.652,12.958Q52.792,12.916,52.841,12.797Q52.89,12.678,52.89,12.384L52.89,9.36Q51.980000000000004,9.724,51.322,9.948L51.013999999999996,8.576Q51.798,8.324,52.89,7.876L52.89,5.524L51.42,5.524L51.42,4.166L52.89,4.166L52.89,1.7579989999999999L54.178,1.814L54.178,4.166L55.214,4.166L55.214,5.524L54.178,5.524L54.178,7.316L54.808,7.022L54.906,7.456ZM56.894,4.5440000000000005L56.894,6.098L55.564,6.098L55.564,3.256L58.686,3.256Q58.42,2.346,58.266,1.9260000000000002L59.624,1.7579989999999999Q59.848,2.276,60.142,3.256L63.25,3.256L63.25,6.098L61.962,6.098L61.962,4.5440000000000005L56.894,4.5440000000000005ZM59.008,6.322Q58.392,6.938,57.685,7.512Q56.978,8.086,55.956,8.841999999999999L55.242,7.764Q56.824,6.728,58.126,5.37L59.008,6.322ZM60.422,5.37Q61.024,5.776,62.095,6.581Q63.166,7.386,63.656,7.806L62.942,8.982Q62.368,8.45,61.332,7.652Q60.296,6.854,59.666,6.434L60.422,5.37ZM62.592,10.256L60.044,10.256L60.044,12.566L63.572,12.566L63.572,13.826L55.144,13.826L55.144,12.566L58.63,12.566L58.63,10.256L56.054,10.256L56.054,8.982L62.592,8.982L62.592,10.256Z", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + } + ] + } + ] + } + ] + }, + "name": "AliyunIcon" } diff --git a/web/app/components/base/icons/src/public/tracing/AliyunIconBig.json b/web/app/components/base/icons/src/public/tracing/AliyunIconBig.json index ea60744daf..7ed5166461 100644 --- a/web/app/components/base/icons/src/public/tracing/AliyunIconBig.json +++ b/web/app/components/base/icons/src/public/tracing/AliyunIconBig.json @@ -1,71 +1,78 @@ { - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "xmlns": "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink", - "fill": "none", - "version": "1.1", - "width": "96", - "height": "24", - "viewBox": "0 0 96 24" - }, - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M6.10294,22C5.68819,22,5.18195,21.8532,4.58421,21.5595C4.46861,21.5027,4.39038,21.4658,4.34951,21.4488C3.49789,21.0943,2.74367,20.5941,2.08684,19.9484C0.695613,18.5806,0,16.9311,0,15C0,13.0689,0.695612,11.4194,2.08684,10.0516C3.24259,8.91537,4.59607,8.2511,6.14728,8.05878C6.34758,5.97414,7.22633,4.16634,8.78354,2.63539C10.5706,0.878463,12.7286,0,15.2573,0C17.2884,0,19.1146,0.595472,20.7358,1.78642C22.327,2.95528,23.4151,4.46783,24,6.32406L22.0568,6.91594C21.6024,5.47377,20.7561,4.29798,19.5181,3.38858C18.2579,2.46286,16.8377,2,15.2573,2C13.2903,2,11.6119,2.6832,10.222,4.04961C8.83217,5.41601,8.13725,7.06614,8.13725,9L8.13725,10L7.12009,10C5.71758,10,4.51932,10.4886,3.52532,11.4659C2.53132,12.4431,2.03431,13.6211,2.03431,15C2.03431,16.3789,2.53132,17.5569,3.52532,18.5341C3.99531,18.9962,4.53447,19.3538,5.14278,19.6071C5.2229,19.6405,5.33983,19.695,5.49356,19.7705C5.80505,19.9235,6.00818,20,6.10294,20L6.10294,22Z", - "fill-rule": "evenodd", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M20.18796103515625,11.66909C19.46346103515625,11.5762,18.72726103515625,11.52975,17.991011035156248,11.52975C16.728921035156247,11.52975,15.45515103515625,11.66909,14.23981103515625,11.91292C13.02447103515625,12.156749999999999,11.85588103515625,12.539909999999999,10.73402103515625,12.98113C9.98612103515625,13.306239999999999,9.23822103515625,13.69327,8.49031803515625,14.14223C7.99950790415625,14.43251,7.85927603515625,15.10595,8.15142503515625,15.59361L11.11966103515625,19.9478C11.45855103515625,20.4354,12.13634103515625,20.5747,12.627151035156249,20.2845C12.821921035156251,20.152900000000002,13.14523103515625,19.990299999999998,13.59708103515625,19.796799999999998C14.27487103515625,19.506500000000003,14.964341035156249,19.3091,15.68887103515625,19.169800000000002C16.413401035156248,19.018900000000002,17.14962103515625,18.926000000000002,17.93258103515625,18.926000000000002L20.071061035156248,11.715530000000001L20.18796103515625,11.66909ZM22.91076103515625,12.20319L22.525161035156252,8L20.18796103515625,11.6807C20.72556103515625,11.72714,21.21636103515625,11.82003,21.74216103515625,11.92453C21.74216103515625,11.91679,22.13166103515625,12.00968,22.91076103515625,12.20319ZM18.09616103515625,18.9724L17.06782103515625,22.4557L18.773961035156248,24L21.11116103515625,23.465899999999998L21.788961035156248,19.5414C21.298161035156248,19.402,20.81896103515625,19.2511,20.32816103515625,19.1582C19.60366103515625,19.076900000000002,18.86746103515625,18.9724,18.09616103515625,18.9724ZM27.49166103515625,14.14223C26.74376103515625,13.69327,25.99586103515625,13.306239999999999,25.24796103515625,12.98113C24.52346103515625,12.69086,23.74046103515625,12.40058,22.95756103515625,12.20319L22.95756103515625,12.40058L21.69546103515625,19.5646C21.89416103515625,19.6575,22.139561035156248,19.7039,22.32646103515625,19.8084C22.77836103515625,20.0019,23.101661035156248,20.1645,23.29646103515625,20.2961C23.78726103515625,20.586399999999998,24.51176103515625,20.4354,24.80396103515625,19.959400000000002L27.77216103515625,15.605229999999999C28.16946103515625,15.05951,28.02926103515625,14.43251,27.49166103515625,14.14223Z", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M35.785,3.8624638671875L50.233000000000004,3.8624638671875L50.233000000000004,5.9204638671875L35.785,5.9204638671875L35.785,3.8624638671875ZM43.156,11.5904638671875Q42.106,13.4594638671875,40.7515,15.4754638671875Q39.397,17.4914638671875,38.599000000000004,18.3104638671875L46.978,17.5334638671875Q45.382,15.0974638671875,44.479,13.8794638671875L46.306,12.7034638671875Q47.397999999999996,14.1104638671875,48.9835,16.3784638671875Q50.569,18.6464638671875,51.492999999999995,20.1794638671875L49.54,21.6494638671875Q49.057,20.8094638671875,48.238,19.5074638671875Q46.243,19.6334638671875,42.82,19.9274638671875Q39.397,20.2214638671875,37.024,20.4734638671875L36.184,20.5784638671875L35.47,20.6834638671875L34.84,18.5624638671875Q35.281,18.4154638671875,35.512,18.2579638671875Q35.743,18.1004638671875,35.9005,17.963963867187502Q36.058,17.8274638671875,36.121,17.7644638671875Q37.087,16.840463867187502,38.305,15.1079638671875Q39.522999999999996,13.3754638671875,40.531,11.5904638671875L34,11.5904638671875L34,9.5114638671875L52.018,9.5114638671875L52.018,11.5904638671875L43.156,11.5904638671875ZM62.203,10.9814638671875L62.203,12.7244638671875L60.25,12.7244638671875L60.25,2.5814638671875L62.203,2.6654638671875L62.203,10.4144638671875Q63.19,8.6294638671875,64.051,6.4139638671875Q64.912,4.1984638671875,65.28999999999999,2.3504638671875L67.348,2.8334628671875Q67.15899999999999,3.7784638671875,66.80199999999999,4.9754638671875L72.619,4.9754638671875L72.619,6.9704638671875L66.13,6.9704638671875Q65.143,9.7004638671875,63.778,12.0524638671875L62.203,10.9814638671875ZM56.113,3.3794638671875L58.045,3.4634638671875L58.045,12.1784638671875L56.113,12.1784638671875L56.113,3.3794638671875ZM67.495,7.3064638671875Q68.251,7.8944638671875,69.469,9.1229638671875Q70.687,10.3514638671875,71.40100000000001,11.2334638671875L69.84700000000001,12.7454638671875Q69.238,11.9684638671875,68.083,10.7714638671875Q66.928,9.5744638671875,66.025,8.7134638671875L67.495,7.3064638671875ZM70.834,13.3754638671875L70.834,18.9194638671875L73.06,18.9194638671875L73.06,20.8094638671875L54.307,20.8094638671875L54.307,18.9194638671875L56.491,18.9194638671875L56.491,13.3754638671875L70.834,13.3754638671875ZM60.733000000000004,15.2444638671875L58.465,15.2444638671875L58.465,18.9194638671875L60.733000000000004,18.9194638671875L60.733000000000004,15.2444638671875ZM62.581,18.9194638671875L64.765,18.9194638671875L64.765,15.2444638671875L62.581,15.2444638671875L62.581,18.9194638671875ZM66.592,18.9194638671875L68.881,18.9194638671875L68.881,15.2444638671875L66.592,15.2444638671875L66.592,18.9194638671875ZM80.578,11.0444638671875L80.893,12.4514638671875L79.48599999999999,13.0814638671875L79.48599999999999,19.0874638671875Q79.48599999999999,20.0114638671875,79.2655,20.4629638671875Q79.045,20.9144638671875,78.52000000000001,21.1034638671875Q77.995,21.2924638671875,76.90299999999999,21.3974638671875L76.021,21.4814638671875L75.43299999999999,19.4864638671875L76.462,19.4024638671875Q76.987,19.3604638671875,77.197,19.2974638671875Q77.407,19.2344638671875,77.4805,19.0559638671875Q77.554,18.8774638671875,77.554,18.4364638671875L77.554,13.9004638671875Q76.189,14.4464638671875,75.202,14.7824638671875L74.74000000000001,12.7244638671875Q75.916,12.3464638671875,77.554,11.6744638671875L77.554,8.1464638671875L75.34899999999999,8.1464638671875L75.34899999999999,6.1094638671875L77.554,6.1094638671875L77.554,2.4974628671875L79.48599999999999,2.5814638671875L79.48599999999999,6.1094638671875L81.03999999999999,6.1094638671875L81.03999999999999,8.1464638671875L79.48599999999999,8.1464638671875L79.48599999999999,10.8344638671875L80.431,10.3934638671875L80.578,11.0444638671875ZM83.56,6.6764638671875L83.56,9.0074638671875L81.565,9.0074638671875L81.565,4.7444638671875L86.24799999999999,4.7444638671875Q85.84899999999999,3.3794638671875,85.618,2.7494638671875L87.655,2.4974628671875Q87.991,3.2744638671875,88.432,4.7444638671875L93.094,4.7444638671875L93.094,9.0074638671875L91.162,9.0074638671875L91.162,6.6764638671875L83.56,6.6764638671875ZM86.731,9.3434638671875Q85.807,10.2674638671875,84.7465,11.1284638671875Q83.686,11.9894638671875,82.15299999999999,13.1234638671875L81.082,11.5064638671875Q83.455,9.9524638671875,85.408,7.9154638671875L86.731,9.3434638671875ZM88.852,7.9154638671875Q89.755,8.5244638671875,91.3615,9.731963867187499Q92.968,10.9394638671875,93.703,11.5694638671875L92.632,13.3334638671875Q91.771,12.5354638671875,90.217,11.3384638671875Q88.663,10.1414638671875,87.718,9.5114638671875L88.852,7.9154638671875ZM92.107,15.2444638671875L88.285,15.2444638671875L88.285,18.7094638671875L93.577,18.7094638671875L93.577,20.5994638671875L80.935,20.5994638671875L80.935,18.7094638671875L86.164,18.7094638671875L86.164,15.2444638671875L82.3,15.2444638671875L82.3,13.3334638671875L92.107,13.3334638671875L92.107,15.2444638671875Z", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - } - ] - } - ] - }, - "name": "AliyunBigIcon" + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + "fill": "none", + "version": "1.1", + "width": "96", + "height": "24", + "viewBox": "0 0 96 24" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M6.10294,22C5.68819,22,5.18195,21.8532,4.58421,21.5595C4.46861,21.5027,4.39038,21.4658,4.34951,21.4488C3.49789,21.0943,2.74367,20.5941,2.08684,19.9484C0.695613,18.5806,0,16.9311,0,15C0,13.0689,0.695612,11.4194,2.08684,10.0516C3.24259,8.91537,4.59607,8.2511,6.14728,8.05878C6.34758,5.97414,7.22633,4.16634,8.78354,2.63539C10.5706,0.878463,12.7286,0,15.2573,0C17.2884,0,19.1146,0.595472,20.7358,1.78642C22.327,2.95528,23.4151,4.46783,24,6.32406L22.0568,6.91594C21.6024,5.47377,20.7561,4.29798,19.5181,3.38858C18.2579,2.46286,16.8377,2,15.2573,2C13.2903,2,11.6119,2.6832,10.222,4.04961C8.83217,5.41601,8.13725,7.06614,8.13725,9L8.13725,10L7.12009,10C5.71758,10,4.51932,10.4886,3.52532,11.4659C2.53132,12.4431,2.03431,13.6211,2.03431,15C2.03431,16.3789,2.53132,17.5569,3.52532,18.5341C3.99531,18.9962,4.53447,19.3538,5.14278,19.6071C5.2229,19.6405,5.33983,19.695,5.49356,19.7705C5.80505,19.9235,6.00818,20,6.10294,20L6.10294,22Z", + "fill-rule": "evenodd", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M20.18796103515625,11.66909C19.46346103515625,11.5762,18.72726103515625,11.52975,17.991011035156248,11.52975C16.728921035156247,11.52975,15.45515103515625,11.66909,14.23981103515625,11.91292C13.02447103515625,12.156749999999999,11.85588103515625,12.539909999999999,10.73402103515625,12.98113C9.98612103515625,13.306239999999999,9.23822103515625,13.69327,8.49031803515625,14.14223C7.99950790415625,14.43251,7.85927603515625,15.10595,8.15142503515625,15.59361L11.11966103515625,19.9478C11.45855103515625,20.4354,12.13634103515625,20.5747,12.627151035156249,20.2845C12.821921035156251,20.152900000000002,13.14523103515625,19.990299999999998,13.59708103515625,19.796799999999998C14.27487103515625,19.506500000000003,14.964341035156249,19.3091,15.68887103515625,19.169800000000002C16.413401035156248,19.018900000000002,17.14962103515625,18.926000000000002,17.93258103515625,18.926000000000002L20.071061035156248,11.715530000000001L20.18796103515625,11.66909ZM22.91076103515625,12.20319L22.525161035156252,8L20.18796103515625,11.6807C20.72556103515625,11.72714,21.21636103515625,11.82003,21.74216103515625,11.92453C21.74216103515625,11.91679,22.13166103515625,12.00968,22.91076103515625,12.20319ZM18.09616103515625,18.9724L17.06782103515625,22.4557L18.773961035156248,24L21.11116103515625,23.465899999999998L21.788961035156248,19.5414C21.298161035156248,19.402,20.81896103515625,19.2511,20.32816103515625,19.1582C19.60366103515625,19.076900000000002,18.86746103515625,18.9724,18.09616103515625,18.9724ZM27.49166103515625,14.14223C26.74376103515625,13.69327,25.99586103515625,13.306239999999999,25.24796103515625,12.98113C24.52346103515625,12.69086,23.74046103515625,12.40058,22.95756103515625,12.20319L22.95756103515625,12.40058L21.69546103515625,19.5646C21.89416103515625,19.6575,22.139561035156248,19.7039,22.32646103515625,19.8084C22.77836103515625,20.0019,23.101661035156248,20.1645,23.29646103515625,20.2961C23.78726103515625,20.586399999999998,24.51176103515625,20.4354,24.80396103515625,19.959400000000002L27.77216103515625,15.605229999999999C28.16946103515625,15.05951,28.02926103515625,14.43251,27.49166103515625,14.14223Z", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M35.785,3.8624638671875L50.233000000000004,3.8624638671875L50.233000000000004,5.9204638671875L35.785,5.9204638671875L35.785,3.8624638671875ZM43.156,11.5904638671875Q42.106,13.4594638671875,40.7515,15.4754638671875Q39.397,17.4914638671875,38.599000000000004,18.3104638671875L46.978,17.5334638671875Q45.382,15.0974638671875,44.479,13.8794638671875L46.306,12.7034638671875Q47.397999999999996,14.1104638671875,48.9835,16.3784638671875Q50.569,18.6464638671875,51.492999999999995,20.1794638671875L49.54,21.6494638671875Q49.057,20.8094638671875,48.238,19.5074638671875Q46.243,19.6334638671875,42.82,19.9274638671875Q39.397,20.2214638671875,37.024,20.4734638671875L36.184,20.5784638671875L35.47,20.6834638671875L34.84,18.5624638671875Q35.281,18.4154638671875,35.512,18.2579638671875Q35.743,18.1004638671875,35.9005,17.963963867187502Q36.058,17.8274638671875,36.121,17.7644638671875Q37.087,16.840463867187502,38.305,15.1079638671875Q39.522999999999996,13.3754638671875,40.531,11.5904638671875L34,11.5904638671875L34,9.5114638671875L52.018,9.5114638671875L52.018,11.5904638671875L43.156,11.5904638671875ZM62.203,10.9814638671875L62.203,12.7244638671875L60.25,12.7244638671875L60.25,2.5814638671875L62.203,2.6654638671875L62.203,10.4144638671875Q63.19,8.6294638671875,64.051,6.4139638671875Q64.912,4.1984638671875,65.28999999999999,2.3504638671875L67.348,2.8334628671875Q67.15899999999999,3.7784638671875,66.80199999999999,4.9754638671875L72.619,4.9754638671875L72.619,6.9704638671875L66.13,6.9704638671875Q65.143,9.7004638671875,63.778,12.0524638671875L62.203,10.9814638671875ZM56.113,3.3794638671875L58.045,3.4634638671875L58.045,12.1784638671875L56.113,12.1784638671875L56.113,3.3794638671875ZM67.495,7.3064638671875Q68.251,7.8944638671875,69.469,9.1229638671875Q70.687,10.3514638671875,71.40100000000001,11.2334638671875L69.84700000000001,12.7454638671875Q69.238,11.9684638671875,68.083,10.7714638671875Q66.928,9.5744638671875,66.025,8.7134638671875L67.495,7.3064638671875ZM70.834,13.3754638671875L70.834,18.9194638671875L73.06,18.9194638671875L73.06,20.8094638671875L54.307,20.8094638671875L54.307,18.9194638671875L56.491,18.9194638671875L56.491,13.3754638671875L70.834,13.3754638671875ZM60.733000000000004,15.2444638671875L58.465,15.2444638671875L58.465,18.9194638671875L60.733000000000004,18.9194638671875L60.733000000000004,15.2444638671875ZM62.581,18.9194638671875L64.765,18.9194638671875L64.765,15.2444638671875L62.581,15.2444638671875L62.581,18.9194638671875ZM66.592,18.9194638671875L68.881,18.9194638671875L68.881,15.2444638671875L66.592,15.2444638671875L66.592,18.9194638671875ZM80.578,11.0444638671875L80.893,12.4514638671875L79.48599999999999,13.0814638671875L79.48599999999999,19.0874638671875Q79.48599999999999,20.0114638671875,79.2655,20.4629638671875Q79.045,20.9144638671875,78.52000000000001,21.1034638671875Q77.995,21.2924638671875,76.90299999999999,21.3974638671875L76.021,21.4814638671875L75.43299999999999,19.4864638671875L76.462,19.4024638671875Q76.987,19.3604638671875,77.197,19.2974638671875Q77.407,19.2344638671875,77.4805,19.0559638671875Q77.554,18.8774638671875,77.554,18.4364638671875L77.554,13.9004638671875Q76.189,14.4464638671875,75.202,14.7824638671875L74.74000000000001,12.7244638671875Q75.916,12.3464638671875,77.554,11.6744638671875L77.554,8.1464638671875L75.34899999999999,8.1464638671875L75.34899999999999,6.1094638671875L77.554,6.1094638671875L77.554,2.4974628671875L79.48599999999999,2.5814638671875L79.48599999999999,6.1094638671875L81.03999999999999,6.1094638671875L81.03999999999999,8.1464638671875L79.48599999999999,8.1464638671875L79.48599999999999,10.8344638671875L80.431,10.3934638671875L80.578,11.0444638671875ZM83.56,6.6764638671875L83.56,9.0074638671875L81.565,9.0074638671875L81.565,4.7444638671875L86.24799999999999,4.7444638671875Q85.84899999999999,3.3794638671875,85.618,2.7494638671875L87.655,2.4974628671875Q87.991,3.2744638671875,88.432,4.7444638671875L93.094,4.7444638671875L93.094,9.0074638671875L91.162,9.0074638671875L91.162,6.6764638671875L83.56,6.6764638671875ZM86.731,9.3434638671875Q85.807,10.2674638671875,84.7465,11.1284638671875Q83.686,11.9894638671875,82.15299999999999,13.1234638671875L81.082,11.5064638671875Q83.455,9.9524638671875,85.408,7.9154638671875L86.731,9.3434638671875ZM88.852,7.9154638671875Q89.755,8.5244638671875,91.3615,9.731963867187499Q92.968,10.9394638671875,93.703,11.5694638671875L92.632,13.3334638671875Q91.771,12.5354638671875,90.217,11.3384638671875Q88.663,10.1414638671875,87.718,9.5114638671875L88.852,7.9154638671875ZM92.107,15.2444638671875L88.285,15.2444638671875L88.285,18.7094638671875L93.577,18.7094638671875L93.577,20.5994638671875L80.935,20.5994638671875L80.935,18.7094638671875L86.164,18.7094638671875L86.164,15.2444638671875L82.3,15.2444638671875L82.3,13.3334638671875L92.107,13.3334638671875L92.107,15.2444638671875Z", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + } + ] + } + ] + }, + "name": "AliyunIconBig" } diff --git a/web/app/components/base/image-uploader/image-list.tsx b/web/app/components/base/image-uploader/image-list.tsx index 758ffe99d5..3b5f6dee9c 100644 --- a/web/app/components/base/image-uploader/image-list.tsx +++ b/web/app/components/base/image-uploader/image-list.tsx @@ -62,7 +62,7 @@ const ImageList: FC = ({ {item.progress === -1 && ( onReUpload && onReUpload(item._id)} + onClick={() => onReUpload?.(item._id)} /> )}
@@ -122,7 +122,7 @@ const ImageList: FC = ({ 'rounded-2xl shadow-lg hover:bg-state-base-hover', item.progress === -1 ? 'flex' : 'hidden group-hover:flex', )} - onClick={() => onRemove && onRemove(item._id)} + onClick={() => onRemove?.(item._id)} > diff --git a/web/app/components/base/image-uploader/image-preview.tsx b/web/app/components/base/image-uploader/image-preview.tsx index e67edaa3ca..53c22e344f 100644 --- a/web/app/components/base/image-uploader/image-preview.tsx +++ b/web/app/components/base/image-uploader/image-preview.tsx @@ -20,7 +20,7 @@ const isBase64 = (str: string): boolean => { try { return btoa(atob(str)) === str } - catch (err) { + catch { return false } } diff --git a/web/app/components/base/markdown-blocks/code-block.tsx b/web/app/components/base/markdown-blocks/code-block.tsx index 48de8bf4ab..bc41c65fd5 100644 --- a/web/app/components/base/markdown-blocks/code-block.tsx +++ b/web/app/components/base/markdown-blocks/code-block.tsx @@ -8,12 +8,14 @@ import { import ActionButton from '@/app/components/base/action-button' import CopyIcon from '@/app/components/base/copy-icon' import SVGBtn from '@/app/components/base/svg' -import Flowchart from '@/app/components/base/mermaid' import { Theme } from '@/types/app' import useTheme from '@/hooks/use-theme' import SVGRenderer from '../svg-gallery' // Assumes svg-gallery.tsx is in /base directory import MarkdownMusic from '@/app/components/base/markdown-blocks/music' import ErrorBoundary from '@/app/components/base/markdown/error-boundary' +import dynamic from 'next/dynamic' + +const Flowchart = dynamic(() => import('@/app/components/base/mermaid'), { ssr: false }) // Available language https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/AVAILABLE_LANGUAGES_HLJS.MD const capitalizationLanguageNameMap: Record = { @@ -125,7 +127,7 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any // Store event handlers in useMemo to avoid recreating them const echartsEvents = useMemo(() => ({ - finished: (params: EChartsEventParams) => { + finished: (_params: EChartsEventParams) => { // Limit finished event frequency to avoid infinite loops finishedEventCountRef.current++ if (finishedEventCountRef.current > 3) { diff --git a/web/app/components/base/markdown-blocks/think-block.tsx b/web/app/components/base/markdown-blocks/think-block.tsx index a5813266f1..a3b0561677 100644 --- a/web/app/components/base/markdown-blocks/think-block.tsx +++ b/web/app/components/base/markdown-blocks/think-block.tsx @@ -37,7 +37,7 @@ const removeEndThink = (children: any): any => { const useThinkTimer = (children: any) => { const { isResponding } = useChatContext() - const [startTime] = useState(Date.now()) + const [startTime] = useState(() => Date.now()) const [elapsedTime, setElapsedTime] = useState(0) const [isComplete, setIsComplete] = useState(false) const timerRef = useRef() @@ -63,7 +63,7 @@ const useThinkTimer = (children: any) => { return { elapsedTime, isComplete } } -export const ThinkBlock = ({ children, ...props }: any) => { +const ThinkBlock = ({ children, ...props }: React.ComponentProps<'details'>) => { const { elapsedTime, isComplete } = useThinkTimer(children) const displayContent = removeEndThink(children) const { t } = useTranslation() diff --git a/web/app/components/base/markdown/index.tsx b/web/app/components/base/markdown/index.tsx index bab5ac8eba..19f39d8aaa 100644 --- a/web/app/components/base/markdown/index.tsx +++ b/web/app/components/base/markdown/index.tsx @@ -1,25 +1,11 @@ -import ReactMarkdown from 'react-markdown' +import dynamic from 'next/dynamic' import 'katex/dist/katex.min.css' -import RemarkMath from 'remark-math' -import RemarkBreaks from 'remark-breaks' -import RehypeKatex from 'rehype-katex' -import RemarkGfm from 'remark-gfm' -import RehypeRaw from 'rehype-raw' import { flow } from 'lodash-es' import cn from '@/utils/classnames' -import { customUrlTransform, preprocessLaTeX, preprocessThinkTag } from './markdown-utils' -import { - AudioBlock, - CodeBlock, - Img, - Link, - MarkdownButton, - MarkdownForm, - Paragraph, - ScriptBlock, - ThinkBlock, - VideoBlock, -} from '@/app/components/base/markdown-blocks' +import { preprocessLaTeX, preprocessThinkTag } from './markdown-utils' +import type { ReactMarkdownWrapperProps } from './react-markdown-wrapper' + +const ReactMarkdown = dynamic(() => import('./react-markdown-wrapper').then(mod => mod.ReactMarkdownWrapper), { ssr: false }) /** * @fileoverview Main Markdown rendering component. @@ -31,9 +17,7 @@ import { export type MarkdownProps = { content: string className?: string - customDisallowedElements?: string[] - customComponents?: Record> -} +} & Pick export const Markdown = (props: MarkdownProps) => { const { customComponents = {} } = props @@ -44,53 +28,7 @@ export const Markdown = (props: MarkdownProps) => { return (
- { - return (tree: any) => { - const iterate = (node: any) => { - if (node.type === 'element' && node.properties?.ref) - delete node.properties.ref - - if (node.type === 'element' && !/^[a-z][a-z0-9]*$/i.test(node.tagName)) { - node.type = 'text' - node.value = `<${node.tagName}` - } - - if (node.children) - node.children.forEach(iterate) - } - tree.children.forEach(iterate) - } - }, - ]} - urlTransform={customUrlTransform} - disallowedElements={['iframe', 'head', 'html', 'meta', 'link', 'style', 'body', ...(props.customDisallowedElements || [])]} - components={{ - code: CodeBlock, - img: Img, - video: VideoBlock, - audio: AudioBlock, - a: Link, - p: Paragraph, - button: MarkdownButton, - form: MarkdownForm, - script: ScriptBlock as any, - details: ThinkBlock, - ...customComponents, - }} - > - {/* Markdown detect has problem. */} - {latexContent} - +
) } diff --git a/web/app/components/base/markdown/react-markdown-wrapper.tsx b/web/app/components/base/markdown/react-markdown-wrapper.tsx new file mode 100644 index 0000000000..054b5f66cb --- /dev/null +++ b/web/app/components/base/markdown/react-markdown-wrapper.tsx @@ -0,0 +1,82 @@ +import ReactMarkdown from 'react-markdown' +import RemarkMath from 'remark-math' +import RemarkBreaks from 'remark-breaks' +import RehypeKatex from 'rehype-katex' +import RemarkGfm from 'remark-gfm' +import RehypeRaw from 'rehype-raw' +import AudioBlock from '@/app/components/base/markdown-blocks/audio-block' +import Img from '@/app/components/base/markdown-blocks/img' +import Link from '@/app/components/base/markdown-blocks/link' +import MarkdownButton from '@/app/components/base/markdown-blocks/button' +import MarkdownForm from '@/app/components/base/markdown-blocks/form' +import Paragraph from '@/app/components/base/markdown-blocks/paragraph' +import ScriptBlock from '@/app/components/base/markdown-blocks/script-block' +import ThinkBlock from '@/app/components/base/markdown-blocks/think-block' +import VideoBlock from '@/app/components/base/markdown-blocks/video-block' +import { customUrlTransform } from './markdown-utils' + +import type { FC } from 'react' + +import dynamic from 'next/dynamic' + +const CodeBlock = dynamic(() => import('@/app/components/base/markdown-blocks/code-block'), { ssr: false }) + +export type ReactMarkdownWrapperProps = { + latexContent: any + customDisallowedElements?: string[] + customComponents?: Record> +} + +export const ReactMarkdownWrapper: FC = (props) => { + const { customComponents, latexContent } = props + + return ( + { + return (tree: any) => { + const iterate = (node: any) => { + if (node.type === 'element' && node.properties?.ref) + delete node.properties.ref + + if (node.type === 'element' && !/^[a-z][a-z0-9]*$/i.test(node.tagName)) { + node.type = 'text' + node.value = `<${node.tagName}` + } + + if (node.children) + node.children.forEach(iterate) + } + tree.children.forEach(iterate) + } + }, + ]} + urlTransform={customUrlTransform} + disallowedElements={['iframe', 'head', 'html', 'meta', 'link', 'style', 'body', ...(props.customDisallowedElements || [])]} + components={{ + code: CodeBlock, + img: Img, + video: VideoBlock, + audio: AudioBlock, + a: Link, + p: Paragraph, + button: MarkdownButton, + form: MarkdownForm, + script: ScriptBlock as any, + details: ThinkBlock, + ...customComponents, + }} + > + {/* Markdown detect has problem. */} + {latexContent} + + ) +} diff --git a/web/app/components/base/mermaid/index.tsx b/web/app/components/base/mermaid/index.tsx index c1deab6e09..9b324349f8 100644 --- a/web/app/components/base/mermaid/index.tsx +++ b/web/app/components/base/mermaid/index.tsx @@ -541,7 +541,7 @@ const Flowchart = (props: FlowchartProps) => { {svgString && (
-
} -
} - - +
) } diff --git a/web/app/components/base/video-gallery/VideoPlayer.tsx b/web/app/components/base/video-gallery/VideoPlayer.tsx index d7c86a1af9..c2fcd6ee8d 100644 --- a/web/app/components/base/video-gallery/VideoPlayer.tsx +++ b/web/app/components/base/video-gallery/VideoPlayer.tsx @@ -234,13 +234,13 @@ const VideoPlayer: React.FC = ({ src }) => {
- {!isSmallSize && ({formatTime(currentTime)} / {formatTime(duration)})}
- {!isSmallSize && ( @@ -264,7 +264,7 @@ const VideoPlayer: React.FC = ({ src }) => {
)} -
diff --git a/web/app/components/base/voice-input/index.tsx b/web/app/components/base/voice-input/index.tsx index 5a5400ad30..6587a61217 100644 --- a/web/app/components/base/voice-input/index.tsx +++ b/web/app/components/base/voice-input/index.tsx @@ -81,7 +81,8 @@ const VoiceInput = ({ setStartRecord(false) setStartConvert(true) recorder.current.stop() - drawRecordId.current && cancelAnimationFrame(drawRecordId.current) + if (drawRecordId.current) + cancelAnimationFrame(drawRecordId.current) drawRecordId.current = null const canvas = canvasRef.current! const ctx = ctxRef.current! diff --git a/web/app/components/billing/pricing/footer.tsx b/web/app/components/billing/pricing/footer.tsx index 4e3cdfee3d..fd713eb3da 100644 --- a/web/app/components/billing/pricing/footer.tsx +++ b/web/app/components/billing/pricing/footer.tsx @@ -2,19 +2,29 @@ import React from 'react' import Link from 'next/link' import { useTranslation } from 'react-i18next' import { RiArrowRightUpLine } from '@remixicon/react' +import { type Category, CategoryEnum } from '.' +import cn from '@/utils/classnames' type FooterProps = { pricingPageURL: string + currentCategory: Category } const Footer = ({ pricingPageURL, + currentCategory, }: FooterProps) => { const { t } = useTranslation() return (
-
+
+ {currentCategory === CategoryEnum.CLOUD && ( +
+ {t('billing.plansCommon.taxTip')} + {t('billing.plansCommon.taxTipSecond')} +
+ )} void @@ -25,7 +30,7 @@ const Pricing: FC = ({ const { plan } = useProviderContext() const { isCurrentWorkspaceManager } = useAppContext() const [planRange, setPlanRange] = React.useState(PlanRange.monthly) - const [currentCategory, setCurrentCategory] = useState('cloud') + const [currentCategory, setCurrentCategory] = useState(CategoryEnum.CLOUD) const canPay = isCurrentWorkspaceManager useKeyPress(['esc'], onCancel) @@ -57,7 +62,7 @@ const Pricing: FC = ({ planRange={planRange} canPay={canPay} /> -
+
diff --git a/web/app/components/billing/pricing/plans/self-hosted-plan-item/button.tsx b/web/app/components/billing/pricing/plans/self-hosted-plan-item/button.tsx index 5308490d79..ffa4dbcb65 100644 --- a/web/app/components/billing/pricing/plans/self-hosted-plan-item/button.tsx +++ b/web/app/components/billing/pricing/plans/self-hosted-plan-item/button.tsx @@ -31,7 +31,7 @@ const Button = ({ }, [theme]) return ( - diff --git a/web/app/components/datasets/create/website/index.tsx b/web/app/components/datasets/create/website/index.tsx index 80d6b52315..7190ca3228 100644 --- a/web/app/components/datasets/create/website/index.tsx +++ b/web/app/components/datasets/create/website/index.tsx @@ -61,7 +61,7 @@ const Website: FC = ({ {t('datasetCreation.stepOne.website.chooseProvider')}
- {ENABLE_WEBSITE_JINAREADER && } - {ENABLE_WEBSITE_FIRECRAWL && } - {ENABLE_WEBSITE_WATERCRAWL &&
- + { diff --git a/web/app/components/datasets/list/dataset-card/index.tsx b/web/app/components/datasets/list/dataset-card/index.tsx index f95e7b2199..b1304e578e 100644 --- a/web/app/components/datasets/list/dataset-card/index.tsx +++ b/web/app/components/datasets/list/dataset-card/index.tsx @@ -11,9 +11,6 @@ import cn from '@/utils/classnames' import { useHover } from 'ahooks' import { RiFileTextFill, RiMoreFill, RiRobot2Fill } from '@remixicon/react' import Tooltip from '@/app/components/base/tooltip' -import { useGetLanguage } from '@/context/i18n' -import dayjs from 'dayjs' -import relativeTime from 'dayjs/plugin/relativeTime' import { checkIsUsedInApp, deleteDataset } from '@/service/datasets' import RenameDatasetModal from '../../rename-modal' import Confirm from '@/app/components/base/confirm' @@ -24,7 +21,7 @@ import AppIcon from '@/app/components/base/app-icon' import CornerLabel from '@/app/components/base/corner-label' import { DOC_FORM_ICON_WITH_BG, DOC_FORM_TEXT } from '@/models/datasets' import { useExportPipelineDSL } from '@/service/use-pipeline' -dayjs.extend(relativeTime) +import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' const EXTERNAL_PROVIDER = 'external' @@ -87,10 +84,7 @@ const DatasetCard = ({ return t('dataset.partialEnabled', { count: dataset.document_count, num: availableDocCount }) }, [t, dataset.document_count, dataset.total_available_documents]) - const language = useGetLanguage() - const formatTimeFromNow = useCallback((time: number) => { - return dayjs(time * 1_000).locale(language === 'zh_Hans' ? 'zh-cn' : language.replace('_', '-')).fromNow() - }, [language]) + const { formatTimeFromNow } = useFormatTimeFromNow() const openRenameModal = useCallback(() => { setShowRenameModal(true) @@ -163,12 +157,12 @@ const DatasetCard = ({ data-disable-nprogress={true} onClick={(e) => { e.preventDefault() - isExternalProvider - ? push(`/datasets/${dataset.id}/hitTesting`) - // eslint-disable-next-line sonarjs/no-nested-conditional - : isPipelineUnpublished - ? push(`/datasets/${dataset.id}/pipeline`) - : push(`/datasets/${dataset.id}/documents`) + if (isExternalProvider) + push(`/datasets/${dataset.id}/hitTesting`) + else if (isPipelineUnpublished) + push(`/datasets/${dataset.id}/pipeline`) + else + push(`/datasets/${dataset.id}/documents`) }} > {!dataset.embedding_available && ( @@ -269,7 +263,7 @@ const DatasetCard = ({ )} / - {`${t('dataset.updated')} ${formatTimeFromNow(dataset.updated_at)}`} + {`${t('dataset.updated')} ${formatTimeFromNow(dataset.updated_at * 1000)}`}
null + +export default DatasetsLoading diff --git a/web/app/components/datasets/preview/index.tsx b/web/app/components/datasets/preview/index.tsx index e69de29bb2..e71c440c20 100644 --- a/web/app/components/datasets/preview/index.tsx +++ b/web/app/components/datasets/preview/index.tsx @@ -0,0 +1,3 @@ +const DatasetPreview = () => null + +export default DatasetPreview diff --git a/web/app/components/develop/code.tsx b/web/app/components/develop/code.tsx index ee67921031..69d5624966 100644 --- a/web/app/components/develop/code.tsx +++ b/web/app/components/develop/code.tsx @@ -13,17 +13,6 @@ import classNames from '@/utils/classnames' import { writeTextToClipboard } from '@/utils/clipboard' import type { PropsWithChildren, ReactElement, ReactNode } from 'react' -const languageNames = { - js: 'JavaScript', - ts: 'TypeScript', - javascript: 'JavaScript', - typescript: 'TypeScript', - php: 'PHP', - python: 'Python', - ruby: 'Ruby', - go: 'Go', -} as { [key: string]: string } - type IChildrenProps = { children: React.ReactNode [key: string]: any @@ -204,8 +193,8 @@ function CodeGroupPanels({ children, targetCode, ...props }: ICodeGroupPanelsPro if ((targetCode?.length ?? 0) > 1) { return ( - {targetCode!.map(code => ( - + {targetCode!.map((code, index) => ( + ))} @@ -217,8 +206,8 @@ function CodeGroupPanels({ children, targetCode, ...props }: ICodeGroupPanelsPro } function usePreventLayoutShift() { - const positionRef = useRef() - const rafRef = useRef() + const positionRef = useRef(null) + const rafRef = useRef(null) useEffect(() => { return () => { diff --git a/web/app/components/develop/doc.tsx b/web/app/components/develop/doc.tsx index ef5e7022c1..82b6b00e44 100644 --- a/web/app/components/develop/doc.tsx +++ b/web/app/components/develop/doc.tsx @@ -168,7 +168,7 @@ const Doc = ({ appDetail }: IDocProps) => { {t('appApi.develop.toc')} -