Merge branch 'main' into feat/trigger

This commit is contained in:
Yeuoly 2025-09-01 18:05:31 +08:00
commit 676648e0b3
480 changed files with 10562 additions and 4040 deletions

View File

@ -0,0 +1,19 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -1,6 +1,7 @@
#!/bin/bash
npm add -g pnpm@10.15.0
corepack enable
cd web && pnpm install
pipx install uv

View File

@ -1,13 +1,7 @@
name: Run Pytest
on:
pull_request:
branches:
- main
paths:
- api/**
- docker/**
- .github/workflows/api-tests.yml
workflow_call:
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}

View File

@ -1,10 +1,9 @@
name: autofix.ci
on:
workflow_call:
pull_request:
branches: [ "main" ]
branches: ["main"]
push:
branches: [ "main" ]
branches: ["main"]
permissions:
contents: read
@ -18,7 +17,7 @@ jobs:
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
python-version: "3.12"
- run: |
cd api
uv sync --dev
@ -29,6 +28,7 @@ jobs:
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
- name: mdformat
run: |
uvx mdformat .

View File

@ -1,13 +1,7 @@
name: DB Migration Test
on:
pull_request:
branches:
- main
- plugins/beta
paths:
- api/migrations/**
- .github/workflows/db-migration-test.yml
workflow_call:
concurrency:
group: db-migration-test-${{ github.ref }}
@ -33,6 +27,12 @@ jobs:
- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env
run: |

78
.github/workflows/main-ci.yml vendored Normal file
View File

@ -0,0 +1,78 @@
name: Main CI Pipeline
on:
pull_request:
branches: ["main"]
push:
branches: ["main"]
permissions:
contents: write
pull-requests: write
checks: write
statuses: write
concurrency:
group: main-ci-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
# Check which paths were changed to determine which tests to run
check-changes:
name: Check Changed Files
runs-on: ubuntu-latest
outputs:
api-changed: ${{ steps.changes.outputs.api }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: changes
with:
filters: |
api:
- 'api/**'
- 'docker/**'
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'
- '.github/workflows/vdb-tests.yml'
- 'api/uv.lock'
- 'api/pyproject.toml'
migration:
- 'api/migrations/**'
- '.github/workflows/db-migration-test.yml'
# Run tests in parallel
api-tests:
name: API Tests
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
web-tests:
name: Web Tests
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
style-check:
name: Style Check
uses: ./.github/workflows/style.yml
vdb-tests:
name: VDB Tests
needs: check-changes
if: needs.check-changes.outputs.vdb-changed == 'true'
uses: ./.github/workflows/vdb-tests.yml
db-migration-test:
name: DB Migration Test
needs: check-changes
if: needs.check-changes.outputs.migration-changed == 'true'
uses: ./.github/workflows/db-migration-test.yml

View File

@ -1,9 +1,7 @@
name: Style check
on:
pull_request:
branches:
- main
workflow_call:
concurrency:
group: style-${{ github.head_ref || github.run_id }}
@ -46,21 +44,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: |
uv run --directory api ruff --version
uv run --directory api ruff check ./
uv run --directory api ruff format --check ./
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
web-style:
name: Web Style
runs-on: ubuntu-latest
@ -102,7 +89,9 @@ jobs:
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run lint
run: |
pnpm run lint
pnpm run eslint
docker-compose-template:
name: Docker Compose Template

View File

@ -1,15 +1,7 @@
name: Run VDB Tests
on:
pull_request:
branches:
- main
paths:
- api/core/rag/datasource/**
- docker/**
- .github/workflows/vdb-tests.yml
- api/uv.lock
- api/pyproject.toml
workflow_call:
concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }}

View File

@ -1,11 +1,7 @@
name: Web Tests
on:
pull_request:
branches:
- main
paths:
- web/**
workflow_call:
concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}

34
.mcp.json Normal file
View File

@ -0,0 +1,34 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@ -86,3 +86,4 @@ pnpm test # Run Jest tests
## Project-Specific Conventions
- All async tasks use Celery with Redis as broker
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.

View File

@ -71,7 +71,7 @@ from .app import (
)
# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
# Import billing controllers
from .billing import billing, compliance

View File

@ -3,6 +3,7 @@ import logging
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api
@ -94,21 +95,22 @@ class ChatMessageListApi(Resource):
.all()
)
has_more = False
# Initialize has_more based on whether we have a full page
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
rest_count = (
db.session.query(Message)
.where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
# Check if there are more messages before the current page
has_more = db.session.scalar(
select(
exists().where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
)
.count()
)
if rest_count > 0:
has_more = True
else:
# If we don't have a full page, there are no more messages
has_more = False
history_messages = list(reversed(history_messages))
@ -128,7 +130,7 @@ class MessageFeedbackApi(Resource):
message_id = str(args["message_id"])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")

View File

@ -0,0 +1,187 @@
from functools import wraps
from typing import cast
import flask_login
from flask import request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
from .. import api
def oauth_server_client_id_required(view):
@wraps(view)
def decorated(*args, **kwargs):
parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id")
if not client_id:
raise BadRequest("client_id is required")
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app:
raise NotFound("client_id is invalid")
kwargs["oauth_provider_app"] = oauth_provider_app
return view(*args, **kwargs)
return decorated
def oauth_server_access_token_required(view):
@wraps(view)
def decorated(*args, **kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app")
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
parts = authorization_header.strip().split(" ")
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
token_type = parts[0].strip()
if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
if not account:
raise BadRequest("access_token or client_id is invalid")
kwargs["account"] = account
return view(*args, **kwargs)
return decorated
class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri")
# check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
return jsonable_encoder(
{
"app_icon": oauth_provider_app.app_icon,
"app_label": oauth_provider_app.app_label,
"scope": oauth_provider_app.scope,
}
)
class OAuthServerUserAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user)
user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
return jsonable_encoder(
{
"code": code,
}
)
class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("grant_type", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=False, location="json")
parser.add_argument("client_secret", type=str, required=False, location="json")
parser.add_argument("redirect_uri", type=str, required=False, location="json")
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
raise BadRequest("code is required")
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not parsed_args["refresh_token"]:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
class OAuthServerUserAccountApi(Resource):
@setup_required
@oauth_server_client_id_required
@oauth_server_access_token_required
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
return jsonable_encoder(
{
"name": account.name,
"email": account.email,
"avatar": account.avatar,
"interface_language": account.interface_language,
"timezone": account.timezone,
}
)
api.add_resource(OAuthServerAppApi, "/oauth/provider")
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")

View File

@ -1,8 +1,12 @@
from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view):
def billing_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
return decorated
def enterprise_inner_api_only(view):
def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
return decorated
def plugin_inner_api_only(view):
def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)

View File

@ -1,20 +1,28 @@
from typing import Optional, Union
from flask import Response
from flask_restx import Resource, reqparse
from pydantic import ValidationError
from sqlalchemy.orm import Session
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
from core.mcp.types import ClientNotification, ClientRequest
from core.mcp.utils import create_mcp_error_response
from core.mcp import types as mcp_types
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
class MCPRequestError(Exception):
"""Custom exception for MCP request processing errors"""
def __init__(self, error_code: int, message: str):
self.error_code = error_code
self.message = message
super().__init__(message)
def int_or_str(value):
"""Validate that a value is either an integer or string."""
if isinstance(value, (int, str)):
@ -63,76 +71,128 @@ class MCPAppApi(Resource):
Raises:
ValidationError: Invalid request format or parameters
"""
# Parse and validate all arguments
args = mcp_request_parser.parse_args()
request_id: Optional[Union[int, str]] = args.get("id")
mcp_request = self._parse_mcp_request(args)
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
)
with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
self._validate_server_status(mcp_server)
if server.status != AppMCPServerStatus.ACTIVE:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
)
# Get user input form
user_input_form = self._get_user_input_form(app)
app = db.session.query(App).where(App.id == server.app_id).first()
# Handle notification vs request differently
return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
"""Get and validate MCP server and app in one query session"""
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not mcp_server:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
app = session.query(App).where(App.id == mcp_server.app_id).first()
if not app:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
)
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
return mcp_server, app
user_input_form = workflow.user_input_form(to_old_structure=True)
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
"""Validate MCP server status"""
if mcp_server.status != AppMCPServerStatus.ACTIVE:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
def _process_mcp_message(
self,
mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
request_id: Optional[Union[int, str]],
app: App,
mcp_server: AppMCPServer,
user_input_form: list[VariableEntity],
session: Session,
) -> Response:
"""Process MCP message (notification or request)"""
if isinstance(mcp_request, mcp_types.ClientNotification):
return self._handle_notification(mcp_request)
else:
app_model_config = app.app_model_config
if app_model_config is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
converted_user_input_form: list[VariableEntity] = []
try:
for item in user_input_form:
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
converted_user_input_form.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
)
except ValidationError as e:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
)
def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
"""Handle MCP notification"""
# For notifications, only support init notification
if mcp_request.root.method != "notifications/initialized":
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
# Return HTTP 202 Accepted for notifications (no response body)
return Response("", status=202, content_type="application/json")
def _handle_request(
self,
mcp_request: mcp_types.ClientRequest,
request_id: Optional[Union[int, str]],
app: App,
mcp_server: AppMCPServer,
user_input_form: list[VariableEntity],
session: Session,
) -> Response:
"""Handle MCP request"""
if request_id is None:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")
result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
if result is None:
# This shouldn't happen for requests, but handle gracefully
raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")
return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))
def _get_user_input_form(self, app: App) -> list[VariableEntity]:
"""Get and convert user input form"""
# Get raw user input form based on app mode
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if not app.workflow:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
else:
if not app.app_model_config:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
features_dict = app.app_model_config.to_dict()
raw_user_input_form = features_dict.get("user_input_form", [])
# Convert to VariableEntity objects
try:
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
return self._convert_user_input_form(raw_user_input_form)
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
return VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)
except ValidationError:
try:
notification = ClientNotification.model_validate(args)
request = notification
return mcp_types.ClientNotification.model_validate(args)
except ValidationError as e:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
)
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
response = mcp_server_handler.handle()

View File

@ -318,10 +318,6 @@ class DatasetApi(DatasetApiResource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)

View File

@ -1,6 +1,6 @@
from typing import Literal
from flask_login import current_user # type: ignore
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound

View File

@ -1,12 +1,12 @@
import time
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from enum import StrEnum, auto
from functools import wraps
from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
class WhereisUserArg(Enum):
class WhereisUserArg(StrEnum):
"""
Enum for whereis_user_arg.
"""
QUERY = "query"
JSON = "json"
FORM = "form"
QUERY = auto()
JSON = auto()
FORM = auto()
class FetchUserArg(BaseModel):

View File

@ -1,19 +1,20 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.external_api import ExternalApi
from .files import FileApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("web", __name__, url_prefix="/api")
api = ExternalApi(bp)
# Files
api.add_resource(FileApi, "/files/upload")
api = ExternalApi(
bp,
version="1.0",
title="Web API",
description="Public APIs for web applications including file uploads, chat interactions, and app management",
doc="/docs", # Enable Swagger UI at /api/docs
)
# Remote files
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Create namespace
web_ns = Namespace("web", description="Web application API operations", path="/")
from . import (
app,
@ -21,11 +22,15 @@ from . import (
completion,
conversation,
feature,
files,
forgot_password,
login,
message,
passport,
remote_files,
saved_message,
site,
workflow,
)
api.add_namespace(web_ns)

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Unauthorized
from controllers.common import fields
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@ -19,9 +19,22 @@ from services.webapp_auth_service import WebAppAuthService
logger = logging.getLogger(__name__)
@web_ns.route("/parameters")
class AppParameterApi(WebApiResource):
"""Resource for app variables."""
@web_ns.doc("Get App Parameters")
@web_ns.doc(description="Retrieve the parameters for a specific app.")
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "App Not Found",
500: "Internal Server Error",
}
)
@marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
@ -44,13 +57,42 @@ class AppParameterApi(WebApiResource):
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@web_ns.route("/meta")
class AppMeta(WebApiResource):
@web_ns.doc("Get App Meta")
@web_ns.doc(description="Retrieve the metadata for a specific app.")
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "App Not Found",
500: "Internal Server Error",
}
)
def get(self, app_model: App, end_user):
"""Get app meta"""
return AppService().get_app_meta(app_model)
@web_ns.route("/webapp/access-mode")
class AppAccessMode(Resource):
@web_ns.doc("Get App Access Mode")
@web_ns.doc(description="Retrieve the access mode for a web application (public or restricted).")
@web_ns.doc(
params={
"appId": {"description": "Application ID", "type": "string", "required": False},
"appCode": {"description": "Application code", "type": "string", "required": False},
}
)
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
500: "Internal Server Error",
}
)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=False, location="args")
@ -74,7 +116,19 @@ class AppAccessMode(Resource):
return {"accessMode": res.access_mode}
@web_ns.route("/webapp/permission")
class AppWebAuthPermission(Resource):
@web_ns.doc("Check App Permission")
@web_ns.doc(description="Check if user has permission to access a web application.")
@web_ns.doc(params={"appId": {"description": "Application ID", "type": "string", "required": True}})
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
500: "Internal Server Error",
}
)
def get(self):
user_id = "visitor"
try:
@ -112,10 +166,3 @@ class AppWebAuthPermission(Resource):
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
return {"result": res}
api.add_resource(AppParameterApi, "/parameters")
api.add_resource(AppMeta, "/meta")
# webapp auth apis
api.add_resource(AppAccessMode, "/webapp/access-mode")
api.add_resource(AppWebAuthPermission, "/webapp/permission")

View File

@ -1,6 +1,7 @@
import logging
from flask import request
from flask_restx import fields, marshal_with, reqparse
from werkzeug.exceptions import InternalServerError
import services
@ -32,7 +33,26 @@ logger = logging.getLogger(__name__)
class AudioApi(WebApiResource):
audio_to_text_response_fields = {
"text": fields.String,
}
@marshal_with(audio_to_text_response_fields)
@api.doc("Audio to Text")
@api.doc(description="Convert audio file to text using speech-to-text service.")
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
413: "Audio file too large",
415: "Unsupported audio type",
500: "Internal Server Error",
}
)
def post(self, app_model: App, end_user):
"""Convert audio to text"""
file = request.files["file"]
try:
@ -66,9 +86,25 @@ class AudioApi(WebApiResource):
class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
from flask_restx import reqparse
text_to_audio_response_fields = {
"audio_url": fields.String,
"duration": fields.Float,
}
@marshal_with(text_to_audio_response_fields)
@api.doc("Text to Audio")
@api.doc(description="Convert text to audio using text-to-speech service.")
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
500: "Internal Server Error",
}
)
def post(self, app_model: App, end_user):
"""Convert text to audio"""
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")

View File

@ -36,6 +36,32 @@ logger = logging.getLogger(__name__)
# define completion api for user
class CompletionApi(WebApiResource):
@api.doc("Create Completion Message")
@api.doc(description="Create a completion message for text generation applications.")
@api.doc(
params={
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
"query": {"description": "Query text for completion", "type": "string", "required": False},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "App Not Found",
500: "Internal Server Error",
}
)
def post(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -81,6 +107,19 @@ class CompletionApi(WebApiResource):
class CompletionStopApi(WebApiResource):
@api.doc("Stop Completion Message")
@api.doc(description="Stop a running completion message task.")
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Task Not Found",
500: "Internal Server Error",
}
)
def post(self, app_model, end_user, task_id):
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -91,6 +130,34 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
@api.doc("Create Chat Message")
@api.doc(description="Create a chat message for conversational applications.")
@api.doc(
params={
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
"query": {"description": "User query/message", "type": "string", "required": True},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "App Not Found",
500: "Internal Server Error",
}
)
def post(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -141,6 +208,19 @@ class ChatApi(WebApiResource):
class ChatStopApi(WebApiResource):
@api.doc("Stop Chat Message")
@api.doc(description="Stop a running chat message task.")
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Task Not Found",
500: "Internal Server Error",
}
)
def post(self, app_model, end_user, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:

View File

@ -1,4 +1,4 @@
from flask_restx import marshal_with, reqparse
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -58,6 +58,11 @@ class ConversationListApi(WebApiResource):
class ConversationApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -94,6 +99,11 @@ class ConversationRenameApi(WebApiResource):
class ConversationPinApi(WebApiResource):
pin_response_fields = {
"result": fields.String,
}
@marshal_with(pin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -110,6 +120,11 @@ class ConversationPinApi(WebApiResource):
class ConversationUnPinApi(WebApiResource):
unpin_response_fields = {
"result": fields.String,
}
@marshal_with(unpin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:

View File

@ -1,12 +1,21 @@
from flask_restx import Resource
from controllers.web import api
from controllers.web import web_ns
from services.feature_service import FeatureService
@web_ns.route("/system-features")
class SystemFeatureApi(Resource):
@web_ns.doc("get_system_features")
@web_ns.doc(description="Get system feature flags and configuration")
@web_ns.doc(responses={200: "System features retrieved successfully", 500: "Internal server error"})
def get(self):
"""Get system feature flags and configuration.
Returns the current system feature flags and configuration
that control various functionalities across the platform.
Returns:
dict: System feature configuration object
"""
return FeatureService.get_system_features().model_dump()
api.add_resource(SystemFeatureApi, "/system-features")

View File

@ -9,14 +9,50 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from fields.file_fields import file_fields
from fields.file_fields import build_file_model
from services.file_service import FileService
@web_ns.route("/files/upload")
class FileApi(WebApiResource):
@marshal_with(file_fields)
@web_ns.doc("upload_file")
@web_ns.doc(description="Upload a file for use in web applications")
@web_ns.doc(
responses={
201: "File uploaded successfully",
400: "Bad request - invalid file or parameters",
413: "File too large",
415: "Unsupported file type",
}
)
@marshal_with(build_file_model(web_ns))
def post(self, app_model, end_user):
"""Upload a file for use in web applications.
Accepts file uploads for use within web applications, supporting
multiple file types with automatic validation and storage.
Args:
app_model: The associated application model
end_user: The end user uploading the file
Form Parameters:
file: The file to upload (required)
source: Optional source type (datasets or None)
Returns:
dict: File information including ID, URL, and metadata
int: HTTP status code 201 for success
Raises:
NoFileUploadedError: No file provided in request
TooManyFilesError: Multiple files provided (only one allowed)
FilenameNotExistsError: File has no filename
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
if "file" not in request.files:
raise NoFileUploadedError()

View File

@ -16,7 +16,7 @@ from controllers.console.auth.error import (
)
from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import api
from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
@ -24,10 +24,21 @@ from models.account import Account
from services.account_service import AccountService
@web_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@web_ns.doc("send_forgot_password_email")
@web_ns.doc(description="Send password reset email")
@web_ns.doc(
responses={
200: "Password reset email sent successfully",
400: "Bad request - invalid email format",
404: "Account not found",
429: "Too many requests - rate limit exceeded",
}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@ -54,10 +65,16 @@ class ForgotPasswordSendEmailApi(Resource):
return {"result": "success", "data": token}
@web_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@web_ns.doc("check_forgot_password_token")
@web_ns.doc(description="Verify password reset token validity")
@web_ns.doc(
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
@ -94,10 +111,21 @@ class ForgotPasswordCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@web_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@web_ns.doc("reset_password")
@web_ns.doc(description="Reset user password with verification token")
@web_ns.doc(
responses={
200: "Password reset successfully",
400: "Bad request - invalid parameters or password mismatch",
401: "Invalid or expired token",
404: "Account not found",
}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
@ -141,8 +169,3 @@ class ForgotPasswordResetApi(Resource):
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@ -1,5 +1,5 @@
from flask_restx import Resource, reqparse
from jwt import InvalidTokenError # type: ignore
from jwt import InvalidTokenError
import services
from controllers.console.auth.error import (
@ -9,18 +9,30 @@ from controllers.console.auth.error import (
)
from controllers.console.error import AccountBannedError
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import api
from controllers.web import web_ns
from libs.helper import email
from libs.password import valid_password
from services.account_service import AccountService
from services.webapp_auth_service import WebAppAuthService
@web_ns.route("/login")
class LoginApi(Resource):
"""Resource for web app email/password login."""
@setup_required
@only_edition_enterprise
@web_ns.doc("web_app_login")
@web_ns.doc(description="Authenticate user for web application access")
@web_ns.doc(
responses={
200: "Authentication successful",
400: "Bad request - invalid email or password format",
401: "Authentication failed - email or password mismatch",
403: "Account banned or login disabled",
404: "Account not found",
}
)
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
@ -51,9 +63,19 @@ class LoginApi(Resource):
# return {"result": "success"}
@web_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
@only_edition_enterprise
@web_ns.doc("send_email_code_login")
@web_ns.doc(description="Send email verification code for login")
@web_ns.doc(
responses={
200: "Email code sent successfully",
400: "Bad request - invalid email format",
404: "Account not found",
}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@ -74,9 +96,20 @@ class EmailCodeLoginSendEmailApi(Resource):
return {"result": "success", "data": token}
@web_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource):
@setup_required
@only_edition_enterprise
@web_ns.doc("verify_email_code_login")
@web_ns.doc(description="Verify email code and complete login")
@web_ns.doc(
responses={
200: "Email code verified and login successful",
400: "Bad request - invalid code or token",
401: "Invalid token or expired code",
404: "Account not found",
}
)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
@ -104,9 +137,3 @@ class EmailCodeLoginApi(Resource):
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": {"access_token": token}}
api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

View File

@ -85,6 +85,11 @@ class MessageListApi(WebApiResource):
class MessageFeedbackApi(WebApiResource):
feedback_response_fields = {
"result": fields.String,
}
@marshal_with(feedback_response_fields)
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
@ -152,6 +157,11 @@ class MessageMoreLikeThisApi(WebApiResource):
class MessageSuggestedQuestionApi(WebApiResource):
suggested_questions_response_fields = {
"data": fields.List(fields.String),
}
@marshal_with(suggested_questions_response_fields)
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:

View File

@ -7,7 +7,7 @@ from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
@ -17,9 +17,19 @@ from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
@web_ns.route("/passport")
class PassportResource(Resource):
"""Base resource for passport."""
@web_ns.doc("get_passport")
@web_ns.doc(description="Get authentication passport for web application access")
@web_ns.doc(
responses={
200: "Passport retrieved successfully",
401: "Unauthorized - missing app code or invalid authentication",
404: "Application or user not found",
}
)
def get(self):
system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code")
@ -94,9 +104,6 @@ class PassportResource(Resource):
}
api.add_resource(PassportResource, "/passport")
def decode_enterprise_webapp_user_id(jwt_token: str | None):
"""
Decode the enterprise user session from the Authorization header.

View File

@ -10,16 +10,44 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
from services.file_service import FileService
@web_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(WebApiResource):
@marshal_with(remote_file_info_fields)
@web_ns.doc("get_remote_file_info")
@web_ns.doc(description="Get information about a remote file")
@web_ns.doc(
responses={
200: "Remote file information retrieved successfully",
400: "Bad request - invalid URL",
404: "Remote file not found",
500: "Failed to fetch remote file",
}
)
@marshal_with(build_remote_file_info_model(web_ns))
def get(self, app_model, end_user, url):
"""Get information about a remote file.
Retrieves basic information about a file located at a remote URL,
including content type and content length.
Args:
app_model: The associated application model
end_user: The end user making the request
url: URL-encoded path to the remote file
Returns:
dict: Remote file information including type and length
Raises:
HTTPException: If the remote file cannot be accessed
"""
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
@ -32,9 +60,42 @@ class RemoteFileInfoApi(WebApiResource):
}
@web_ns.route("/remote-files/upload")
class RemoteFileUploadApi(WebApiResource):
@marshal_with(file_fields_with_signed_url)
def post(self, app_model, end_user): # Add app_model and end_user parameters
@web_ns.doc("upload_remote_file")
@web_ns.doc(description="Upload a file from a remote URL")
@web_ns.doc(
responses={
201: "Remote file uploaded successfully",
400: "Bad request - invalid URL or parameters",
413: "File too large",
415: "Unsupported file type",
500: "Failed to fetch remote file",
}
)
@marshal_with(build_file_with_signed_url_model(web_ns))
def post(self, app_model, end_user):
"""Upload a file from a remote URL.
Downloads a file from the provided remote URL and uploads it
to the platform storage for use in web applications.
Args:
app_model: The associated application model
end_user: The end user making the request
JSON Parameters:
url: The remote URL to download the file from (required)
Returns:
dict: File information including ID, signed URL, and metadata
int: HTTP status code 201 for success
Raises:
RemoteFileUploadError: Failed to fetch file from remote URL
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()

View File

@ -30,6 +30,10 @@ class SavedMessageListApi(WebApiResource):
"data": fields.List(fields.Nested(message_fields)),
}
post_response_fields = {
"result": fields.String,
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != "completion":
@ -42,6 +46,7 @@ class SavedMessageListApi(WebApiResource):
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
@marshal_with(post_response_fields)
def post(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -59,6 +64,11 @@ class SavedMessageListApi(WebApiResource):
class SavedMessageApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)

View File

@ -53,6 +53,18 @@ class AppSiteApi(WebApiResource):
"custom_config": fields.Raw(attribute="custom_config"),
}
@api.doc("Get App Site Info")
@api.doc(description="Retrieve app site information and configuration.")
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "App Not Found",
500: "Internal Server Error",
}
)
@marshal_with(app_fields)
def get(self, app_model, end_user):
"""Retrieve app site info."""

View File

@ -30,6 +30,24 @@ logger = logging.getLogger(__name__)
class WorkflowRunApi(WebApiResource):
@api.doc("Run Workflow")
@api.doc(description="Execute a workflow with provided inputs and files.")
@api.doc(
params={
"inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
}
)
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "App Not Found",
500: "Internal Server Error",
}
)
def post(self, app_model: App, end_user: EndUser):
"""
Run workflow
@ -67,6 +85,23 @@ class WorkflowRunApi(WebApiResource):
class WorkflowTaskStopApi(WebApiResource):
@api.doc("Stop Workflow Task")
@api.doc(description="Stop a running workflow task.")
@api.doc(
params={
"task_id": {"description": "Task ID to stop", "type": "string", "required": True},
}
)
@api.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Task Not Found",
500: "Internal Server Error",
}
)
def post(self, app_model: App, end_user: EndUser, task_id: str):
"""
Stop workflow task

View File

@ -1,3 +1,16 @@
from pydantic import BaseModel, ConfigDict, Field, ValidationError
class MoreLikeThisConfig(BaseModel):
enabled: bool = False
model_config = ConfigDict(extra="allow")
class AppConfigModel(BaseModel):
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
model_config = ConfigDict(extra="allow")
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
@ -6,31 +19,14 @@ class MoreLikeThisConfigManager:
:param config: model config args
"""
more_like_this = False
more_like_this_dict = config.get("more_like_this")
if more_like_this_dict:
if more_like_this_dict.get("enabled"):
more_like_this = True
return more_like_this
validated_config, _ = cls.validate_and_set_defaults(config)
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for more like this feature
:param config: app model config args
"""
if not config.get("more_like_this"):
config["more_like_this"] = {"enabled": False}
if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type")
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
config["more_like_this"]["enabled"] = False
if not isinstance(config["more_like_this"]["enabled"], bool):
raise ValueError("enabled in more_like_this must be of boolean type")
return config, ["more_like_this"]
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError as e:
raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
)

View File

@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
conversation_variables=conversation_variables,
)
# init graph

View File

@ -1,4 +1,5 @@
import logging
import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
@ -143,6 +144,7 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._task_state = WorkflowTaskState()
@ -373,7 +375,7 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
self._recorded_files.extend(
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
@ -896,7 +898,14 @@ class AdvancedChatAppGenerateTaskPipeline:
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer
if self._recorded_files:
# Remove markdown image links since we're storing files separately
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = self._task_state.metadata.model_dump_json()

View File

@ -1,4 +1,3 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union, final
@ -14,6 +13,7 @@ from core.workflow.repositories.draft_variable_repository import (
NoopDraftVariableSaver,
)
from factories import file_factory
from libs.orjson import orjson_dumps
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
@ -174,7 +174,7 @@ class BaseAppGenerator:
def gen():
for message in generator:
if isinstance(message, Mapping | dict):
yield f"data: {json.dumps(message)}\n\n"
yield f"data: {orjson_dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"

View File

@ -1,7 +1,7 @@
import queue
import time
from abc import abstractmethod
from enum import Enum
from enum import IntEnum, auto
from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client
class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto()
TASK_PIPELINE = auto()
class AppQueueManager:

View File

@ -3,7 +3,6 @@ from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
@ -53,9 +52,7 @@ from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowRun,
)
@ -64,8 +61,10 @@ class WorkflowResponseConverter:
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
) -> None:
self._application_generate_entity = application_generate_entity
self._user = user
def workflow_start_to_stream_response(
self,
@ -92,27 +91,21 @@ class WorkflowResponseConverter:
workflow_execution: WorkflowExecution,
) -> WorkflowFinishStreamResponse:
created_by = None
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
assert workflow_run is not None
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
stmt = select(Account).where(Account.id == workflow_run.created_by)
account = session.scalar(stmt)
if account:
created_by = {
"id": account.id,
"name": account.name,
"email": account.email,
}
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
end_user = session.scalar(stmt)
if end_user:
created_by = {
"id": end_user.id,
"user": end_user.session_id,
}
user = self._user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
elif isinstance(user, EndUser):
created_by = {
"id": user.id,
"user": user.session_id,
}
else:
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
raise NotImplementedError(f"User type not supported: {type(user)}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (

View File

@ -131,6 +131,7 @@ class WorkflowAppGenerateTaskPipeline:
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._application_generate_entity = application_generate_entity

View File

@ -118,7 +118,7 @@ class QueueIterationNextEvent(AppQueueEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
@ -201,7 +201,7 @@ class QueueLoopNextEvent(AppQueueEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current loop
duration: Optional[float] = None
@ -382,7 +382,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""loop id if node is in loop"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
"""iteration run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None

View File

@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param event: agent thought event
:return:
"""
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:
return AgentThoughtStreamResponse(

View File

@ -1,8 +1,8 @@
class TaskPipilineError(ValueError):
class TaskPipelineError(ValueError):
pass
class RecordNotFoundError(TaskPipilineError):
class RecordNotFoundError(TaskPipelineError):
def __init__(self, record_name: str, record_id: str):
super().__init__(f"{record_name} with id {record_id} not found")

View File

@ -88,6 +88,7 @@ def to_prompt_message_content(
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

View File

@ -3,7 +3,7 @@ import base64
from libs import rsa
def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
import requests
import httpx
from yarl import URL
from configs import dify_config
@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()

View File

@ -4,224 +4,259 @@ from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.mcp.utils import create_mcp_error_response
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from core.mcp import types as mcp_types
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
class MCPServerStreamableHTTPRequestHandler:
def handle_mcp_request(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
mcp_server: AppMCPServer,
end_user: EndUser | None = None,
request_id: int | str = 1,
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
"""
Apply to MCP HTTP streamable server with stateless http
Handle MCP request and return JSON-RPC response
Args:
app: The Dify app instance
request: The JSON-RPC request message
user_input_form: List of variable entities for the app
mcp_server: The MCP server configuration
end_user: Optional end user
request_id: The request ID
Returns:
JSON-RPC response or error
"""
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
self.app = app
self.request = request
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
request_type = type(request.root)
@property
def request_type(self):
return type(self.request.root)
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
"""Create success response with business result data"""
return mcp_types.JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
)
@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
"""Create error response with error code and message"""
from core.mcp.types import ErrorData
error_data = ErrorData(code=code, message=message)
return mcp_types.JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=error_data,
)
# Request handler mapping using functional approach
request_handlers = {
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
mcp_types.ListToolsRequest: lambda: handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
),
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
mcp_types.PingRequest: lambda: handle_ping(),
}
try:
# Dispatch request to appropriate handler
handler = request_handlers.get(request_type)
if handler:
return create_success_response(handler())
else:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
except ValueError as e:
logger.exception("Invalid params")
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))
def handle_ping() -> mcp_types.EmptyResult:
"""Handle ping request"""
return mcp_types.EmptyResult()
def handle_initialize(description: str) -> mcp_types.InitializeResult:
"""Handle initialize request"""
capabilities = mcp_types.ServerCapabilities(
tools=mcp_types.ToolsCapability(listChanged=False),
)
return mcp_types.InitializeResult(
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
instructions=description,
)
def handle_list_tools(
app_name: str,
app_mode: str,
user_input_form: list[VariableEntity],
description: str,
parameters_dict: dict[str, str],
) -> mcp_types.ListToolsResult:
"""Handle list tools request"""
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
return mcp_types.ListToolsResult(
tools=[
mcp_types.Tool(
name=app_name,
description=description,
inputSchema=parameter_schema,
)
],
)
def handle_call_tool(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
end_user: EndUser | None,
) -> mcp_types.CallToolResult:
"""Handle call tool request"""
request_obj = cast(mcp_types.CallToolRequest, request.root)
args = prepare_tool_arguments(app, request_obj.params.arguments or {})
if not end_user:
raise ValueError("End user not found")
response = AppGenerateService.generate(
app,
end_user,
args,
InvokeFrom.SERVICE_API,
streaming=app.mode == AppMode.AGENT_CHAT.value,
)
answer = extract_answer_from_response(app, response)
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])
def build_parameter_schema(
app_mode: str,
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> dict[str, Any]:
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
"properties": parameters,
"required": required,
}
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
}
@property
def capabilities(self):
return types.ServerCapabilities(
tools=types.ToolsCapability(listChanged=False),
)
def response(self, response: types.Result | str):
if isinstance(response, str):
sse_content = f"event: ping\ndata: {response}\n\n".encode()
yield sse_content
return
json_response = types.JSONRPCResponse(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
json_data = json.dumps(jsonable_encoder(json_response))
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW.value:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION.value:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content
def extract_answer_from_response(app: App, response: Any) -> str:
"""Extract answer from app generate response"""
answer = ""
def error_response(self, code: int, message: str, data=None):
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
return create_mcp_error_response(request_id, code, message, data)
if isinstance(response, RateLimitGenerator):
answer = process_streaming_response(response)
elif isinstance(response, Mapping):
answer = process_mapping_response(app, response)
else:
logger.warning("Unexpected response type: %s", type(response))
def handle(self):
handle_map = {
types.InitializeRequest: self.initialize,
types.ListToolsRequest: self.list_tools,
types.CallToolRequest: self.invoke_tool,
types.InitializedNotification: self.handle_notification,
types.PingRequest: self.handle_ping,
}
try:
if self.request_type in handle_map:
return self.response(handle_map[self.request_type]())
else:
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
except ValueError as e:
logger.exception("Invalid params")
return self.error_response(INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
return answer
def handle_notification(self):
return "ping"
def handle_ping(self):
return types.EmptyResult()
def initialize(self):
request = cast(types.InitializeRequest, self.request.root)
client_info = request.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
if not self.end_user:
end_user = EndUser(
tenant_id=self.app.tenant_id,
app_id=self.app.id,
type="mcp",
name=client_name,
session_id=generate_session_id(),
external_user_id=self.mcp_server.id,
)
db.session.add(end_user)
db.session.commit()
return types.InitializeResult(
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities,
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
instructions=self.mcp_server.description,
)
def list_tools(self):
if not self.end_user:
raise ValueError("User not found")
return types.ListToolsResult(
tools=[
types.Tool(
name=self.app.name,
description=self.mcp_server.description,
inputSchema=self.parameter_schema,
)
],
)
def invoke_tool(self):
if not self.end_user:
raise ValueError("User not found")
request = cast(types.CallToolRequest, self.request.root)
args = request.params.arguments or {}
if self.app.mode in {AppMode.WORKFLOW.value}:
args = {"inputs": args}
elif self.app.mode in {AppMode.COMPLETION.value}:
args = {"query": "", "inputs": args}
else:
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
response = AppGenerateService.generate(
self.app,
self.end_user,
args,
InvokeFrom.SERVICE_API,
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
)
answer = ""
if isinstance(response, RateLimitGenerator):
for item in response.generator:
data = item
if isinstance(data, str) and data.startswith("data: "):
try:
json_str = data[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
if isinstance(response, Mapping):
if self.app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
answer = response["answer"]
elif self.app.mode in {AppMode.WORKFLOW.value}:
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode")
# Not support image yet
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
def retrieve_end_user(self):
return (
db.session.query(EndUser)
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
parameters[item.variable] = {}
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
# if the workflow republished, the parameters not changed
# we should not raise error here
def process_streaming_response(response: RateLimitGenerator) -> str:
"""Process streaming response for agent chat mode"""
answer = ""
for item in response.generator:
if isinstance(item, str) and item.startswith("data: "):
try:
description = self.mcp_server.parameters_dict[item.variable]
except KeyError:
description = ""
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required
json_str = item[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
return answer
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW.value:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))
def convert_input_form_to_parameters(
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> tuple[dict[str, dict[str, Any]], list[str]]:
"""Convert user input form to parameter schema"""
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
parameters[item.variable] = {}
if item.required:
required.append(item.variable)
# if the workflow republished, the parameters not changed
# we should not raise error here
description = parameters_dict.get(item.variable, "")
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required

View File

@ -138,5 +138,5 @@ def create_mcp_error_response(
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
sse_content = json_data.encode()
yield sse_content

View File

@ -31,6 +31,65 @@ class TokenBufferMemory:
self.conversation = conversation
self.model_instance = model_instance
def _build_prompt_message_with_files(
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
) -> PromptMessage:
"""
Build prompt message with files.
:param message_files: list of MessageFile objects
:param text_content: text content of the message
:param message: Message object
:param app_record: app record
:param is_user_message: whether this is a user message
:return: PromptMessage
"""
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.HIGH
if file_extra_config and app_record:
# Build files directly without filtering by belongs_to
file_objs = [
file_factory.build_from_message_file(
message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config
)
for message_file in message_files
]
if file_extra_config.image_config and file_extra_config.image_config.detail:
detail = file_extra_config.image_config.detail
else:
file_objs = []
if not file_objs:
if is_user_message:
return UserPromptMessage(content=text_content)
else:
return AssistantPromptMessage(content=text_content)
else:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in file_objs:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_message_contents.append(TextPromptMessageContent(data=text_content))
if is_user_message:
return UserPromptMessage(content=prompt_message_contents)
else:
return AssistantPromptMessage(content=prompt_message_contents)
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
@ -67,52 +126,46 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = []
for message in messages:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.LOW
if file_extra_config and app_record:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
)
if file_extra_config.image_config and file_extra_config.image_config.detail:
detail = file_extra_config.image_config.detail
else:
file_objs = []
if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in file_objs:
prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message)
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
# Process user message with files
user_files = (
db.session.query(MessageFile)
.where(
MessageFile.message_id == message.id,
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
)
.all()
)
if user_files:
user_prompt_message = self._build_prompt_message_with_files(
message_files=user_files,
text_content=message.query,
message=message,
app_record=app_record,
is_user_message=True,
)
prompt_messages.append(user_prompt_message)
else:
prompt_messages.append(UserPromptMessage(content=message.query))
prompt_messages.append(AssistantPromptMessage(content=message.answer))
# Process assistant message with files
assistant_files = (
db.session.query(MessageFile)
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
.all()
)
if assistant_files:
assistant_prompt_message = self._build_prompt_message_with_files(
message_files=assistant_files,
text_content=message.answer,
message=message,
app_record=app_record,
is_user_message=False,
)
prompt_messages.append(assistant_prompt_message)
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))
if not prompt_messages:
return []

View File

@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)

View File

@ -87,6 +87,7 @@ class MultiModalPromptMessageContent(PromptMessageContent):
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
filename: str = Field(default="", description="the filename of multi-modal file")
@property
def data(self):

View File

@ -43,7 +43,7 @@ class GPT2Tokenizer:
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")

View File

@ -330,7 +330,7 @@ class OpsTraceManager:
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
else:
if tracing_provider is not None:
if tracing_provider is None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()

View File

@ -375,16 +375,16 @@ Here is the extra instruction you need to follow:
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines: # type: ignore
for line in new_lines:
if len(messages) == 0:
messages.append(i) # type: ignore
messages.append(line)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
messages[-1] += i # type: ignore
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
messages.append(i) # type: ignore
if len(messages[-1]) + len(line) < max_tokens * 0.5:
messages[-1] += line
if get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
messages.append(line)
else:
messages[-1] += i # type: ignore
messages[-1] += line
summaries = []
for i in range(len(messages)):

View File

@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)

View File

@ -2,7 +2,7 @@ import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@ -154,8 +154,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):

View File

@ -24,7 +24,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document

View File

@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 2)
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config

View File

@ -48,7 +48,7 @@ class OpenSearchConfig(BaseModel):
return values
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
import boto3 # type: ignore
import boto3
return Urllib3AWSV4SignerAuth(
credentials=boto3.Session().get_credentials(),

View File

@ -6,8 +6,8 @@ from contextlib import contextmanager
from typing import Any
import psycopg2.errors
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config

View File

@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config

View File

@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union
import qdrant_client
from flask import current_app
@ -426,7 +426,6 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod

View File

@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union
import qdrant_client
import requests
@ -398,7 +398,6 @@ class TidbOnQdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod

View File

@ -4,7 +4,7 @@ import os
from typing import Optional, cast
import pandas as pd
from openpyxl import load_workbook # type: ignore
from openpyxl import load_workbook
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -73,8 +73,8 @@ class ExtractProcessor:
suffix = "." + match.group(1)
else:
suffix = ""
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text:

View File

@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from bs4 import BeautifulSoup # type: ignore
from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -2,7 +2,7 @@
import re
from pathlib import Path
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]

View File

@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor):
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
return cast(str, data_source_binding.access_token)
return data_source_binding.access_token

View File

@ -2,7 +2,7 @@
import contextlib
from collections.abc import Iterator
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())

View File

@ -3,7 +3,7 @@ import contextlib
import logging
from typing import Optional
from bs4 import BeautifulSoup # type: ignore
from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -647,7 +647,7 @@ class DatasetRetrieval:
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
@ -743,7 +743,7 @@ class DatasetRetrieval:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
top_k=retrieve_config.top_k or 4,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,

View File

@ -144,7 +144,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
"""Text splitter that uses HuggingFace tokenizer to count length."""
try:
from transformers import PreTrainedTokenizerBase # type: ignore
from transformers import PreTrainedTokenizerBase
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

View File

@ -331,16 +331,13 @@ class ToolManager:
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return cast(
WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
@ -648,8 +645,8 @@ class ToolManager:
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name,
):

View File

@ -181,7 +181,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
)
if documents:
all_documents.extend(documents)
@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

View File

@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
top_k: int = 4
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool

View File

@ -3,25 +3,26 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional, cast
from typing import Optional
from uuid import UUID
import numpy as np
import pytz
from flask_login import current_user
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
from libs.login import current_user
from models.account import Account
logger = logging.getLogger(__name__)
def safe_json_value(v):
if isinstance(v, datetime):
tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
if not tz_name:
tz_name = "UTC"
tz_name = "UTC"
if isinstance(current_user, Account) and current_user.timezone is not None:
tz_name = current_user.timezone
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
@ -46,7 +47,7 @@ def safe_json_value(v):
return v
def safe_json_dict(d):
def safe_json_dict(d: dict):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
@ -158,8 +159,7 @@ class ToolFileMessageTransformer:
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
message.message.json_object = safe_json_value(message.message.json_object)
yield message
else:
yield message

View File

@ -129,17 +129,14 @@ class ModelInvocationUtils:
db.session.commit()
try:
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")

View File

@ -6,7 +6,7 @@ from typing import Optional
from flask import request
from requests import get
from yaml import YAMLError, safe_load # type: ignore
from yaml import YAMLError, safe_load
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle

View File

@ -1,9 +1,7 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, cast
from flask_login import current_user
from typing import Any, Optional
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@ -17,8 +15,8 @@ from core.tools.entities.tool_entities import (
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from models.account import Account
from models.model import App, EndUser
from libs.login import current_user
from models.model import App
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -81,11 +79,11 @@ class WorkflowTool(Tool):
generator = WorkflowAppGenerator()
assert self.runtime is not None
assert self.runtime.invoke_from is not None
assert current_user is not None
result = generator.generate(
app_model=app,
workflow=workflow,
user=cast("Account | EndUser", current_user),
user=current_user,
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
streaming=False,
@ -206,14 +204,14 @@ class WorkflowTool(Tool):
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Annotated, TypeAlias, cast
from typing import Annotated, TypeAlias
from uuid import uuid4
from pydantic import Discriminator, Field, Tag
@ -86,7 +86,7 @@ class SecretVariable(StringVariable):
@property
def log(self) -> str:
return cast(str, encrypter.obfuscated_token(self.value))
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):

View File

@ -166,7 +166,7 @@ class BaseIterationEvent(GraphEngineEvent):
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
"""iteration run in parallel mode run id"""
class IterationRunStartedEvent(BaseIterationEvent):

View File

@ -374,7 +374,7 @@ class GraphEngine:
if len(sub_edge_mappings) == 0:
continue
edge = cast(GraphEdge, sub_edge_mappings[0])
edge = sub_edge_mappings[0]
if edge.run_condition is None:
logger.warning("Edge %s run condition is None", edge.target_node_id)
continue

View File

@ -153,7 +153,7 @@ class AgentNode(BaseNode):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"agent_strategy": self._node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@ -394,8 +394,7 @@ class AgentNode(BaseNode):
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:

View File

@ -149,9 +149,6 @@ class AnswerStreamProcessor(StreamProcessor):
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:

View File

@ -302,12 +302,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
encoding = "utf-8"
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
# If decoding fails, try with utf-8 as last resort
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError):
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
@ -515,14 +515,14 @@ def _extract_text_from_excel(file_content: bytes) -> str:
df.dropna(how="all", inplace=True)
# Combine multi-line text in each cell into a single line
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
# Combine multi-line text in column names into a single line
df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
# Manually construct the Markdown table
markdown_table += _construct_markdown_table(df) + "\n\n"
except Exception as e:
except Exception:
continue
return markdown_table
except Exception as e:

View File

@ -78,7 +78,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}

View File

@ -524,7 +524,10 @@ class LoopNode(BaseNode):
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
if not var_type.is_array_type() or var_type == SegmentType.BOOLEAN:
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
# 2. Consider moving this method to LoopVariableData class for better encapsulation
if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
elif var_type in [
SegmentType.ARRAY_NUMBER,

View File

@ -139,7 +139,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self._node_data)
node_data = self._node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -109,7 +109,7 @@ class QuestionClassifierNode(BaseNode):
return "1"
def _run(self):
node_data = cast(QuestionClassifierNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
# extract variables

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -57,7 +57,7 @@ class ToolNode(BaseNode):
Run the tool node
"""
node_data = cast(ToolNodeData, self._node_data)
node_data = self._node_data
# fetch tool icon
tool_info = {

View File

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@ -261,7 +261,6 @@ class WorkflowEntry:
environment_variables=[],
)
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node: BaseNode = node_cls(
id=str(uuid.uuid4()),

View File

@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow_professional,workflow_team,workflow_sandbox"
DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox"
else
# Community edition (SELF_HOSTED): dataset and workflow have separate queues
DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,workflow"
DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

View File

@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
# register blueprint routers
from flask_cors import CORS # type: ignore
from flask_cors import CORS
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp

View File

@ -9,7 +9,7 @@ from typing import Union
import flask
from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in # type: ignore
from flask_login import user_loaded_from_request, user_logged_in
from configs import dify_config
from dify_app import DifyApp

View File

@ -65,7 +65,7 @@ class Storage:
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
return VolcengineTosStorage
case StorageType.SUPBASE:
case StorageType.SUPABASE:
from extensions.storage.supabase_storage import SupabaseStorage
return SupabaseStorage

View File

@ -1,9 +1,9 @@
import logging
from collections.abc import Generator
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.exceptions import ClientError # type: ignore
import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -14,4 +14,4 @@ class StorageType(StrEnum):
S3 = "s3"
TENCENT_COS = "tencent-cos"
VOLCENGINE_TOS = "volcengine-tos"
SUPBASE = "supabase"
SUPABASE = "supabase"

Some files were not shown because too many files have changed in this diff Show More