Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

This commit is contained in:
-LAN- 2025-09-08 14:30:43 +08:00
commit 23cd615489
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
828 changed files with 7240 additions and 2951 deletions

View File

@ -47,6 +47,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check run: dev/basedpyright-check
- name: Run Mypy Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example

8
.gitignore vendored
View File

@ -198,6 +198,7 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template !.vscode/launch.json.template
!.vscode/README.md !.vscode/README.md
api/.vscode api/.vscode
web/.vscode
# vscode Code History Extension # vscode Code History Extension
.history .history
@ -215,6 +216,13 @@ mise.toml
# Next.js build output # Next.js build output
.next/ .next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant # AI Assistant
.roo/ .roo/
api/.env.backup api/.env.backup

View File

@ -25,6 +25,9 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request # add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles() RecyclableContextVar.increment_thread_recycles()
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
return dify_app return dify_app

View File

@ -14,7 +14,6 @@ from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from core.helper import encrypter from core.helper import encrypter
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
@ -1494,7 +1493,7 @@ def transform_datasource_credentials():
for credential in credentials: for credential in credentials:
auth_count += 1 auth_count += 1
# get credential api key # get credential api key
credentials_json =json.loads(credential.credentials) credentials_json = json.loads(credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key") api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url") base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = { new_credentials = {

View File

@ -300,8 +300,7 @@ class DatasetQueueMonitorConfig(BaseSettings):
class MiddlewareConfig( class MiddlewareConfig(
# place the configs in alphabet order # place the configs in alphabet order
CeleryConfig, CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
DatabaseConfig,
KeywordStoreConfig, KeywordStoreConfig,
RedisConfig, RedisConfig,
# configs of storage and storage providers # configs of storage and storage providers

View File

@ -1,9 +1,10 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import Field
from pydantic_settings import BaseSettings
class ClickzettaConfig(BaseModel): class ClickzettaConfig(BaseSettings):
""" """
Clickzetta Lakehouse vector database configuration Clickzetta Lakehouse vector database configuration
""" """

View File

@ -1,7 +1,8 @@
from pydantic import BaseModel, Field from pydantic import Field
from pydantic_settings import BaseSettings
class MatrixoneConfig(BaseModel): class MatrixoneConfig(BaseSettings):
"""Matrixone vector database configuration.""" """Matrixone vector database configuration."""
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server") MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")

View File

@ -1,6 +1,6 @@
from pydantic import Field from pydantic import Field
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig from configs.packaging.pyproject import PyProjectTomlConfig
class PackagingInfo(PyProjectTomlConfig): class PackagingInfo(PyProjectTomlConfig):

View File

@ -4,8 +4,9 @@ import logging
import os import os
import threading import threading
import time import time
from collections.abc import Mapping from collections.abc import Callable, Mapping
from pathlib import Path from pathlib import Path
from typing import Any
from .python_3x import http_request, makedirs_wrapper from .python_3x import http_request, makedirs_wrapper
from .utils import ( from .utils import (
@ -25,13 +26,13 @@ logger = logging.getLogger(__name__)
class ApolloClient: class ApolloClient:
def __init__( def __init__(
self, self,
config_url, config_url: str,
app_id, app_id: str,
cluster="default", cluster: str = "default",
secret="", secret: str = "",
start_hot_update=True, start_hot_update: bool = True,
change_listener=None, change_listener: Callable[[str, str, str, Any], None] | None = None,
_notification_map=None, _notification_map: dict[str, int] | None = None,
): ):
# Core routing parameters # Core routing parameters
self.config_url = config_url self.config_url = config_url
@ -47,17 +48,17 @@ class ApolloClient:
# Private control variables # Private control variables
self._cycle_time = 5 self._cycle_time = 5
self._stopping = False self._stopping = False
self._cache = {} self._cache: dict[str, dict[str, Any]] = {}
self._no_key = {} self._no_key: dict[str, str] = {}
self._hash = {} self._hash: dict[str, str] = {}
self._pull_timeout = 75 self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/" self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread = None self._long_poll_thread: threading.Thread | None = None
self._change_listener = change_listener # "add" "delete" "update" self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None: if _notification_map is None:
_notification_map = {"application": -1} _notification_map = {"application": -1}
self._notification_map = _notification_map self._notification_map = _notification_map
self.last_release_key = None self.last_release_key: str | None = None
# Private startup method # Private startup method
self._path_checker() self._path_checker()
if start_hot_update: if start_hot_update:
@ -68,7 +69,7 @@ class ApolloClient:
heartbeat.daemon = True heartbeat.daemon = True
heartbeat.start() heartbeat.start()
def get_json_from_net(self, namespace="application"): def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format( url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip self.config_url, self.app_id, self.cluster, namespace, "", self.ip
) )
@ -88,7 +89,7 @@ class ApolloClient:
logger.exception("an error occurred in get_json_from_net") logger.exception("an error occurred in get_json_from_net")
return None return None
def get_value(self, key, default_val=None, namespace="application"): def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
try: try:
# read memory configuration # read memory configuration
namespace_cache = self._cache.get(namespace) namespace_cache = self._cache.get(namespace)
@ -104,7 +105,8 @@ class ApolloClient:
namespace_data = self.get_json_from_net(namespace) namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key) val = get_value_from_dict(namespace_data, key)
if val is not None: if val is not None:
self._update_cache_and_file(namespace_data, namespace) if namespace_data is not None:
self._update_cache_and_file(namespace_data, namespace)
return val return val
# read the file configuration # read the file configuration
@ -126,23 +128,23 @@ class ApolloClient:
# to ensure the real-time correctness of the function call. # to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice # If the user does not have the same default val twice
# and the default val is used here, there may be a problem. # and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace, key): def _set_local_cache_none(self, namespace: str, key: str) -> None:
no_key = no_key_cache_key(namespace, key) no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key self._no_key[no_key] = key
def _start_hot_update(self): def _start_hot_update(self) -> None:
self._long_poll_thread = threading.Thread(target=self._listener) self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit # When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched. # when the main thread is launched.
self._long_poll_thread.daemon = True self._long_poll_thread.daemon = True
self._long_poll_thread.start() self._long_poll_thread.start()
def stop(self): def stop(self) -> None:
self._stopping = True self._stopping = True
logger.info("Stopping listener...") logger.info("Stopping listener...")
# Call the set callback function, and if it is abnormal, try it out # Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace, old_kv, new_kv): def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
if self._change_listener is None: if self._change_listener is None:
return return
if old_kv is None: if old_kv is None:
@ -168,12 +170,12 @@ class ApolloClient:
except BaseException as e: except BaseException as e:
logger.warning(str(e)) logger.warning(str(e))
def _path_checker(self): def _path_checker(self) -> None:
if not os.path.isdir(self._cache_file_path): if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path) makedirs_wrapper(self._cache_file_path)
# update the local cache and file cache # update the local cache and file cache
def _update_cache_and_file(self, namespace_data, namespace="application"): def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
# update the local cache # update the local cache
self._cache[namespace] = namespace_data self._cache[namespace] = namespace_data
# update the file cache # update the file cache
@ -187,7 +189,7 @@ class ApolloClient:
self._hash[namespace] = new_hash self._hash[namespace] = new_hash
# get the configuration from the local file # get the configuration from the local file
def _get_local_cache(self, namespace="application"): def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt") cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path): if os.path.isfile(cache_file_path):
with open(cache_file_path) as f: with open(cache_file_path) as f:
@ -195,8 +197,8 @@ class ApolloClient:
return result return result
return {} return {}
def _long_poll(self): def _long_poll(self) -> None:
notifications = [] notifications: list[dict[str, Any]] = []
for key in self._cache: for key in self._cache:
namespace_data = self._cache[key] namespace_data = self._cache[key]
notification_id = -1 notification_id = -1
@ -236,7 +238,7 @@ class ApolloClient:
except Exception as e: except Exception as e:
logger.warning(str(e)) logger.warning(str(e))
def _get_net_and_set_local(self, namespace, n_id, call_change=False): def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
namespace_data = self.get_json_from_net(namespace) namespace_data = self.get_json_from_net(namespace)
if not namespace_data: if not namespace_data:
return return
@ -248,7 +250,7 @@ class ApolloClient:
new_kv = namespace_data.get(CONFIGURATIONS) new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv) self._call_listener(namespace, old_kv, new_kv)
def _listener(self): def _listener(self) -> None:
logger.info("start long_poll") logger.info("start long_poll")
while not self._stopping: while not self._stopping:
self._long_poll() self._long_poll()
@ -266,13 +268,13 @@ class ApolloClient:
headers["Timestamp"] = time_unix_now headers["Timestamp"] = time_unix_now
return headers return headers
def _heart_beat(self): def _heart_beat(self) -> None:
while not self._stopping: while not self._stopping:
for namespace in self._notification_map: for namespace in self._notification_map:
self._do_heart_beat(namespace) self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10 minutes time.sleep(60 * 10) # 10 minutes
def _do_heart_beat(self, namespace): def _do_heart_beat(self, namespace: str) -> None:
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}" url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
try: try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
@ -292,7 +294,7 @@ class ApolloClient:
logger.exception("an error occurred in _do_heart_beat") logger.exception("an error occurred in _do_heart_beat")
return None return None
def get_all_dicts(self, namespace): def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
namespace_data = self._cache.get(namespace) namespace_data = self._cache.get(namespace)
if namespace_data is None: if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace) net_namespace_data = self.get_json_from_net(namespace)

View File

@ -2,6 +2,8 @@ import logging
import os import os
import ssl import ssl
import urllib.request import urllib.request
from collections.abc import Mapping
from typing import Any
from urllib import parse from urllib import parse
from urllib.error import HTTPError from urllib.error import HTTPError
@ -19,9 +21,9 @@ urllib.request.install_opener(opener)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def http_request(url, timeout, headers={}): def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
try: try:
request = urllib.request.Request(url, headers=headers) request = urllib.request.Request(url, headers=dict(headers))
res = urllib.request.urlopen(request, timeout=timeout) res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8") body = res.read().decode("utf-8")
return res.code, body return res.code, body
@ -33,9 +35,9 @@ def http_request(url, timeout, headers={}):
raise e raise e
def url_encode(params): def url_encode(params: dict[str, Any]) -> str:
return parse.urlencode(params) return parse.urlencode(params)
def makedirs_wrapper(path): def makedirs_wrapper(path: str) -> None:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)

View File

@ -1,5 +1,6 @@
import hashlib import hashlib
import socket import socket
from typing import Any
from .python_3x import url_encode from .python_3x import url_encode
@ -10,7 +11,7 @@ NAMESPACE_NAME = "namespaceName"
# add timestamps uris and keys # add timestamps uris and keys
def signature(timestamp, uri, secret): def signature(timestamp: str, uri: str, secret: str) -> str:
import base64 import base64
import hmac import hmac
@ -19,16 +20,16 @@ def signature(timestamp, uri, secret):
return base64.b64encode(hmac_code).decode() return base64.b64encode(hmac_code).decode()
def url_encode_wrapper(params): def url_encode_wrapper(params: dict[str, Any]) -> str:
return url_encode(params) return url_encode(params)
def no_key_cache_key(namespace, key): def no_key_cache_key(namespace: str, key: str) -> str:
return f"{namespace}{len(namespace)}{key}" return f"{namespace}{len(namespace)}{key}"
# Returns whether the obtained value is obtained, and None if it does not # Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache, key): def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
if namespace_cache: if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS) kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None: if kv_data is None:
@ -38,7 +39,7 @@ def get_value_from_dict(namespace_cache, key):
return None return None
def init_ip(): def init_ip() -> str:
ip = "" ip = ""
s = None s = None
try: try:

View File

@ -11,5 +11,5 @@ class RemoteSettingsSource:
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError raise NotImplementedError
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool):
return value return value

View File

@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)
from configs.remote_settings_sources.base import RemoteSettingsSource from configs.remote_settings_sources.base import RemoteSettingsSource
from .utils import _parse_config from .utils import parse_config
class NacosSettingsSource(RemoteSettingsSource): class NacosSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]): def __init__(self, configs: Mapping[str, Any]):
self.configs = configs self.configs = configs
self.remote_configs: dict[str, Any] = {} self.remote_configs: dict[str, str] = {}
self.async_init() self.async_init()
def async_init(self): def async_init(self) -> None:
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties") data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify") group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "") tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
@ -33,18 +33,15 @@ class NacosSettingsSource(RemoteSettingsSource):
logger.exception("[get-access-token] exception occurred") logger.exception("[get-access-token] exception occurred")
raise raise
def _parse_config(self, content: str) -> dict: def _parse_config(self, content: str) -> dict[str, str]:
if not content: if not content:
return {} return {}
try: try:
return _parse_config(self, content) return parse_config(content)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}") raise RuntimeError(f"Failed to parse config: {e}")
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name) field_value = self.remote_configs.get(field_name)
if field_value is None: if field_value is None:
return None, field_name, False return None, field_name, False

View File

@ -17,11 +17,17 @@ class NacosHttpClient:
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY") self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY") self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848") self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
self.token = None self.token: str | None = None
self.token_ttl = 18000 self.token_ttl = 18000
self.token_expire_time: float = 0 self.token_expire_time: float = 0
def http_request(self, url, method="GET", headers=None, params=None): def http_request(
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
) -> str:
if headers is None:
headers = {}
if params is None:
params = {}
try: try:
self._inject_auth_info(headers, params) self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
@ -30,7 +36,7 @@ class NacosHttpClient:
except requests.RequestException as e: except requests.RequestException as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers, params, module="config"): def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"}) headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
if module == "login": if module == "login":
@ -45,16 +51,17 @@ class NacosHttpClient:
headers["timeStamp"] = ts headers["timeStamp"] = ts
if self.username and self.password: if self.username and self.password:
self.get_access_token(force_refresh=False) self.get_access_token(force_refresh=False)
params["accessToken"] = self.token if self.token is not None:
params["accessToken"] = self.token
def __do_sign(self, sign_str, sk): def __do_sign(self, sign_str: str, sk: str) -> str:
return ( return (
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest()) base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
.decode() .decode()
.strip() .strip()
) )
def get_sign_str(self, group, tenant, ts): def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
sign_str = "" sign_str = ""
if tenant: if tenant:
sign_str = tenant + "+" sign_str = tenant + "+"
@ -63,7 +70,7 @@ class NacosHttpClient:
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it. sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
return sign_str return sign_str
def get_access_token(self, force_refresh=False): def get_access_token(self, force_refresh: bool = False) -> str | None:
current_time = time.time() current_time = time.time()
if self.token and not force_refresh and self.token_expire_time > current_time: if self.token and not force_refresh and self.token_expire_time > current_time:
return self.token return self.token
@ -77,6 +84,7 @@ class NacosHttpClient:
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000) self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10 self.token_expire_time = current_time + self.token_ttl - 10
return self.token
except Exception: except Exception:
logger.exception("[get-access-token] exception occur") logger.exception("[get-access-token] exception occur")
raise raise

View File

@ -1,4 +1,4 @@
def _parse_config(self, content: str) -> dict[str, str]: def parse_config(content: str) -> dict[str, str]:
config: dict[str, str] = {} config: dict[str, str] = {}
if not content: if not content:
return config return config

View File

@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
@ -6,6 +8,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
@ -14,9 +18,9 @@ from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
def admin_required(view): def admin_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ADMIN_API_KEY: if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")

View File

@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource):
custom="max_keys_exceeded", custom="max_keys_exceeded",
) )
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_user.current_tenant_id

View File

@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self) -> dict: def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json") parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, NoReturn from typing import NoReturn
from flask import Response from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
@ -31,7 +31,7 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any: def _convert_values_to_json_serializable_object(value: Segment):
if isinstance(value, FileSegment): if isinstance(value, FileSegment):
return value.value.model_dump() return value.value.model_dump()
elif isinstance(value, ArrayFileSegment): elif isinstance(value, ArrayFileSegment):
@ -42,8 +42,7 @@ def _convert_values_to_json_serializable_object(value: Segment) -> Any:
return value.value return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: def _serialize_var_value(variable: WorkflowDraftVariable):
"""Serialize variable value. If variable is truncated, return the truncated value."""
value = variable.get_value() value = variable.get_value()
# create a copy of the value to avoid affecting the model cache. # create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True) value = value.model_copy(deep=True)

View File

@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import cast from typing import Concatenate, ParamSpec, TypeVar, cast
import flask_login import flask_login
from flask import jsonify, request from flask import jsonify, request
@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import api from .. import api
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def oauth_server_client_id_required(view):
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json") parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
if not oauth_provider_app: if not oauth_provider_app:
raise NotFound("client_id is invalid") raise NotFound("client_id is invalid")
kwargs["oauth_provider_app"] = oauth_provider_app return view(self, oauth_provider_app, *args, **kwargs)
return view(*args, **kwargs)
return decorated return decorated
def oauth_server_access_token_required(view): def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app") if not isinstance(oauth_provider_app, OAuthProviderApp):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app") raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization") authorization_header = request.headers.get("Authorization")
@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
response.headers["WWW-Authenticate"] = "Bearer" response.headers["WWW-Authenticate"] = "Bearer"
return response return response
kwargs["account"] = account return view(self, oauth_provider_app, account, *args, **kwargs)
return view(*args, **kwargs)
return decorated return decorated

View File

@ -1,9 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import login_required from libs.login import current_user, login_required
from models.model import Account
from services.billing_service import BillingService from services.billing_service import BillingService
@ -17,9 +17,10 @@ class Subscription(Resource):
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args() args = parser.parse_args()
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription( return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
) )
@ -31,7 +32,9 @@ class Invoices(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)

View File

@ -477,6 +477,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
if document.data_source_type == "upload_file": if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
@ -493,6 +495,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import": elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info={
@ -506,6 +510,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl": elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info={

View File

@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Get app meta""" """Get app meta"""
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model) return AppService().get_app_meta(app_model)

View File

@ -36,6 +36,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
Run workflow Run workflow
""" """
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
@ -74,6 +76,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
Stop workflow task Stop workflow task
""" """
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()

View File

@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource from flask_restx import Resource
@ -13,19 +15,15 @@ from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view=None):
def decorator(view): def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"]
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where( .where(
@ -52,10 +50,10 @@ def installed_app_required(view=None):
return decorator return decorator
def user_allowed_to_access_app(view=None): def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app: InstalledApp, *args, **kwargs): def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features() feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled: if feature.webapp_auth.enabled:
app_id = installed_app.app_id app_id = installed_app.app_id

View File

@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user from flask_login import current_user
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from models.account import TenantPluginPermission from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required( def plugin_permission_required(
install_required: bool = False, install_required: bool = False,
debug_required: bool = False, debug_required: bool = False,
): ):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id

View File

@ -2,7 +2,9 @@ import contextlib
import json import json
import os import os
import time import time
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request from flask import abort, request
from flask_login import current_user from flask_login import current_user
@ -19,10 +21,13 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
def account_initialization_required(view):
def account_initialization_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization # check account initialization
account = current_user account = current_user
@ -34,9 +39,9 @@ def account_initialization_required(view):
return decorated return decorated
def only_edition_cloud(view): def only_edition_cloud(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
abort(404) abort(404)
@ -45,9 +50,9 @@ def only_edition_cloud(view):
return decorated return decorated
def only_edition_enterprise(view): def only_edition_enterprise(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED: if not dify_config.ENTERPRISE_ENABLED:
abort(404) abort(404)
@ -56,9 +61,9 @@ def only_edition_enterprise(view):
return decorated return decorated
def only_edition_self_hosted(view): def only_edition_self_hosted(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED": if dify_config.EDITION != "SELF_HOSTED":
abort(404) abort(404)
@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
return decorated return decorated
def cloud_edition_billing_enabled(view): def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled: if not features.billing.enabled:
abort(403, "Billing feature is not enabled.") abort(403, "Billing feature is not enabled.")
@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
def cloud_edition_billing_resource_check(resource: str): def cloud_edition_billing_resource_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
members = features.members members = features.members
@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_knowledge_limit_check(resource: str): def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
if resource == "add_segment": if resource == "add_segment":
@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_rate_limit_check(resource: str): def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge": if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled: if knowledge_rate_limit.enabled:
@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor return interceptor
def cloud_utm_record(view): def cloud_utm_record(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
@ -194,9 +199,9 @@ def cloud_utm_record(view):
return decorated return decorated
def setup_required(view): def setup_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check setup # check setup
if ( if (
dify_config.EDITION == "SELF_HOSTED" dify_config.EDITION == "SELF_HOSTED"
@ -212,9 +217,9 @@ def setup_required(view):
return decorated return decorated
def enterprise_license_required(view): def enterprise_license_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features() settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
@ -224,9 +229,9 @@ def enterprise_license_required(view):
return decorated return decorated
def email_password_login_enabled(view): def email_password_login_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.enable_email_password_login: if features.enable_email_password_login:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -237,9 +242,9 @@ def email_password_login_enabled(view):
return decorated return decorated
def enable_change_email(view): def enable_change_email(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.enable_change_email: if features.enable_change_email:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -250,9 +255,9 @@ def enable_change_email(view):
return decorated return decorated
def is_allow_transfer_owner(view): def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.is_allow_transfer_workspace: if features.is_allow_transfer_workspace:
return view(*args, **kwargs) return view(*args, **kwargs)

View File

@ -99,7 +99,7 @@ class MCPAppApi(Resource):
return mcp_server, app return mcp_server, app
def _validate_server_status(self, mcp_server: AppMCPServer) -> None: def _validate_server_status(self, mcp_server: AppMCPServer):
"""Validate MCP server status""" """Validate MCP server status"""
if mcp_server.status != AppMCPServerStatus.ACTIVE: if mcp_server.status != AppMCPServerStatus.ACTIVE:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")

View File

@ -440,7 +440,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# validate segment belongs to the specified document # validate segment belongs to the specified document
if segment.document_id != document_id: if str(segment.document_id) != str(document_id):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check child chunk # check child chunk
@ -451,7 +451,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# validate child chunk belongs to the specified segment # validate child chunk belongs to the specified segment
if child_chunk.segment_id != segment.id: if str(child_chunk.segment_id) != str(segment.id):
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
try: try:
@ -500,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# validate segment belongs to the specified document # validate segment belongs to the specified document
if segment.document_id != document_id: if str(segment.document_id) != str(document_id):
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# get child chunk # get child chunk
@ -511,7 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# validate child chunk belongs to the specified segment # validate child chunk belongs to the specified segment
if child_chunk.segment_id != segment.id: if str(child_chunk.segment_id) != str(segment.id):
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# validate args # validate args

View File

@ -3,7 +3,7 @@ from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from enum import StrEnum, auto from enum import StrEnum, auto
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional, ParamSpec, TypeVar
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
class WhereisUserArg(StrEnum): class WhereisUserArg(StrEnum):
""" """
@ -60,27 +63,6 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.") raise Forbidden("The workspace's status is archived.")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
kwargs["app_model"] = app_model kwargs["app_model"] = app_model
if fetch_user_arg: if fetch_user_arg:
@ -118,8 +100,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id) features = FeatureService.get_features(api_token.tenant_id)
@ -148,9 +130,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id) features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled: if features.billing.enabled:
@ -170,9 +152,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
if resource == "knowledge": if resource == "knowledge":

View File

@ -1,5 +1,6 @@
from datetime import UTC, datetime from datetime import UTC, datetime
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token(view=None): def validate_jwt_token(view=None):
def decorator(view): def decorator(view):

View File

@ -62,7 +62,7 @@ class BaseAgentRunner(AppRunner):
model_instance: ModelInstance, model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None, prompt_messages: Optional[list[PromptMessage]] = None,
) -> None: ):
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.conversation = conversation self.conversation = conversation

View File

@ -338,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return instruction return instruction
def _init_react_state(self, query) -> None: def _init_react_state(self, query):
""" """
init agent scratchpad init agent scratchpad
""" """

View File

@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel):
action_name: str action_name: str
action_input: Union[dict, str] action_input: Union[dict, str]
def to_dict(self) -> dict: def to_dict(self):
""" """
Convert to dictionary. Convert to dictionary.
""" """

View File

@ -158,7 +158,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod @classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict: def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
""" """
Extract dataset config for legacy compatibility Extract dataset config for legacy compatibility

View File

@ -105,7 +105,7 @@ class ModelConfigManager:
return dict(config), ["model"] return dict(config), ["model"]
@classmethod @classmethod
def validate_model_completion_params(cls, cp: dict) -> dict: def validate_model_completion_params(cls, cp: dict):
# model.completion_params # model.completion_params
if not isinstance(cp, dict): if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type") raise ValueError("model.completion_params must be of object type")

View File

@ -122,7 +122,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod @classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: def validate_post_prompt_and_set_defaults(cls, config: dict):
""" """
Validate post_prompt and set defaults for prompt feature Validate post_prompt and set defaults for prompt feature

View File

@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
""" """
Validate for advanced chat app model config Validate for advanced chat app model config

View File

@ -481,7 +481,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id: str, message_id: str,
context: contextvars.Context, context: contextvars.Context,
variable_loader: VariableLoader, variable_loader: VariableLoader,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app

View File

@ -55,7 +55,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow: Workflow, workflow: Workflow,
system_user_id: str, system_user_id: str,
app: App, app: App,
) -> None: ):
super().__init__( super().__init__(
queue_manager=queue_manager, queue_manager=queue_manager,
variable_loader=variable_loader, variable_loader=variable_loader,
@ -69,7 +69,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.system_user_id = system_user_id self.system_user_id = system_user_id
self._app = app self._app = app
def run(self) -> None: def run(self):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config) app_config = cast(AdvancedChatAppConfig, app_config)
@ -184,6 +184,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
), ),
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth, call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=command_channel, command_channel=command_channel,
) )
@ -238,7 +239,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
return False return False
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy):
""" """
Direct output Direct output
""" """

View File

@ -96,7 +96,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
) -> None: ):
self._base_task_pipeline = BasedGenerateTaskPipeline( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
@ -284,7 +284,7 @@ class AdvancedChatAppGenerateTaskPipeline:
session.rollback() session.rollback()
raise raise
def _ensure_workflow_initialized(self) -> None: def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state.""" """Fluent validation for workflow state."""
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
@ -835,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None):
message = self._get_message(session=session) message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer # If there are assistant files, remove markdown image links from answer

View File

@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict: def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
""" """
Validate for agent chat app model config Validate for agent chat app model config

View File

@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation_id: str, conversation_id: str,
message_id: str, message_id: str,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app

View File

@ -35,7 +35,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
) -> None: ):
""" """
Run assistant application Run assistant application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity

View File

@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse _blocking_response_type = ChatbotAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response

View File

@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC):
return metadata return metadata
@classmethod @classmethod
def _error_to_stream_response(cls, e: Exception) -> dict: def _error_to_stream_response(cls, e: Exception):
""" """
Error to stream response. Error to stream response.
:param e: exception :param e: exception

View File

@ -158,7 +158,7 @@ class BaseAppGenerator:
return value return value
def _sanitize_value(self, value: Any) -> Any: def _sanitize_value(self, value: Any):
if isinstance(value, str): if isinstance(value, str):
return value.replace("\x00", "") return value.replace("\x00", "")
return value return value

View File

@ -25,7 +25,7 @@ class PublishFrom(IntEnum):
class AppQueueManager: class AppQueueManager:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
if not user_id: if not user_id:
raise ValueError("user is required") raise ValueError("user is required")
@ -73,14 +73,14 @@ class AppQueueManager:
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10 last_ping_time = elapsed_time // 10
def stop_listen(self) -> None: def stop_listen(self):
""" """
Stop listen to queue Stop listen to queue
:return: :return:
""" """
self._q.put(None) self._q.put(None)
def publish_error(self, e, pub_from: PublishFrom) -> None: def publish_error(self, e, pub_from: PublishFrom):
""" """
Publish error Publish error
:param e: error :param e: error
@ -89,7 +89,7 @@ class AppQueueManager:
""" """
self.publish(QueueErrorEvent(error=e), pub_from) self.publish(QueueErrorEvent(error=e), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
""" """
Publish event to queue Publish event to queue
:param event: :param event:
@ -100,7 +100,7 @@ class AppQueueManager:
self._publish(event, pub_from) self._publish(event, pub_from)
@abstractmethod @abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
""" """
Publish event to queue Publish event to queue
:param event: :param event:
@ -110,7 +110,7 @@ class AppQueueManager:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str):
""" """
Set task stop flag Set task stop flag
:return: :return:

View File

@ -162,7 +162,7 @@ class AppRunner:
text: str, text: str,
stream: bool, stream: bool,
usage: Optional[LLMUsage] = None, usage: Optional[LLMUsage] = None,
) -> None: ):
""" """
Direct output Direct output
:param queue_manager: application queue manager :param queue_manager: application queue manager
@ -204,7 +204,7 @@ class AppRunner:
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
agent: bool = False, agent: bool = False,
) -> None: ):
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
@ -220,9 +220,7 @@ class AppRunner:
else: else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct( def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
""" """
Handle invoke result direct Handle invoke result direct
:param invoke_result: invoke result :param invoke_result: invoke result
@ -239,7 +237,7 @@ class AppRunner:
def _handle_invoke_result_stream( def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
) -> None: ):
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result

View File

@ -81,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict: def config_validate(cls, tenant_id: str, config: dict):
""" """
Validate for chat app model config Validate for chat app model config

View File

@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation_id: str, conversation_id: str,
message_id: str, message_id: str,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app

View File

@ -33,7 +33,7 @@ class ChatAppRunner(AppRunner):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
) -> None: ):
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity

View File

@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse _blocking_response_type = ChatbotAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response

View File

@ -56,7 +56,7 @@ class WorkflowResponseConverter:
*, *,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser], user: Union[Account, EndUser],
) -> None: ):
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._user = user self._user = user
self._truncator = VariableTruncator.default() self._truncator = VariableTruncator.default()

View File

@ -66,7 +66,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict: def config_validate(cls, tenant_id: str, config: dict):
""" """
Validate for completion app model config Validate for completion app model config

View File

@ -192,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
application_generate_entity: CompletionAppGenerateEntity, application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
message_id: str, message_id: str,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -262,6 +262,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise MessageNotExistsError() raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config current_app_model_config = app_model.app_model_config
if not current_app_model_config:
raise MoreLikeThisDisabledError()
more_like_this = current_app_model_config.more_like_this_dict more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:

View File

@ -27,7 +27,7 @@ class CompletionAppRunner(AppRunner):
def run( def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
) -> None: ):
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity

View File

@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse _blocking_response_type = CompletionAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response

View File

@ -14,14 +14,14 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager): class MessageBasedAppQueueManager(AppQueueManager):
def __init__( def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
) -> None: ):
super().__init__(task_id, user_id, invoke_from) super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id) self._conversation_id = str(conversation_id)
self._app_mode = app_mode self._app_mode = app_mode
self._message_id = str(message_id) self._message_id = str(message_id)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
""" """
Publish event to queue Publish event to queue
:param event: :param event:

View File

@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
""" """
Validate for workflow app model config Validate for workflow app model config

View File

@ -14,12 +14,12 @@ from core.app.entities.queue_entities import (
class WorkflowAppQueueManager(AppQueueManager): class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str):
super().__init__(task_id, user_id, invoke_from) super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
""" """
Publish event to queue Publish event to queue
:param event: :param event:

View File

@ -34,7 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader, variable_loader: VariableLoader,
workflow: Workflow, workflow: Workflow,
system_user_id: str, system_user_id: str,
) -> None: ):
super().__init__( super().__init__(
queue_manager=queue_manager, queue_manager=queue_manager,
variable_loader=variable_loader, variable_loader=variable_loader,
@ -44,7 +44,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._workflow = workflow self._workflow = workflow
self._sys_user_id = system_user_id self._sys_user_id = system_user_id
def run(self) -> None: def run(self):
""" """
Run application Run application
""" """
@ -127,6 +127,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
), ),
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth, call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
command_channel=command_channel, command_channel=command_channel,
) )

View File

@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse _blocking_response_type = WorkflowAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.to_dict()) return dict(blocking_response.to_dict())
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response

View File

@ -88,7 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
) -> None: ):
self._base_task_pipeline = BasedGenerateTaskPipeline( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
@ -259,7 +259,7 @@ class WorkflowAppGenerateTaskPipeline:
session.rollback() session.rollback()
raise raise
def _ensure_workflow_initialized(self) -> None: def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state.""" """Fluent validation for workflow state."""
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
@ -697,7 +697,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher: if tts_publisher:
tts_publisher.publish(None) tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
invoke_from = self._application_generate_entity.invoke_from invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API: if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API created_from = WorkflowAppLogCreatedFrom.SERVICE_API

View File

@ -67,7 +67,7 @@ class WorkflowBasedAppRunner:
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str, app_id: str,
) -> None: ):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._variable_loader = variable_loader self._variable_loader = variable_loader
self._app_id = app_id self._app_id = app_id
@ -348,7 +348,7 @@ class WorkflowBasedAppRunner:
return graph, variable_pool return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
""" """
Handle event Handle event
:param workflow_entry: workflow entry :param workflow_entry: workflow entry
@ -580,5 +580,5 @@ class WorkflowBasedAppRunner:
) )
) )
def _publish_event(self, event: AppQueueEvent) -> None: def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

View File

@ -35,7 +35,7 @@ class BasedGenerateTaskPipeline:
application_generate_entity: AppGenerateEntity, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
) -> None: ):
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager self.queue_manager = queue_manager
self._start_at = time.perf_counter() self._start_at = time.perf_counter()

View File

@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
stream: bool, stream: bool,
) -> None: ):
super().__init__( super().__init__(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None):
""" """
Save message. Save message.
:return: :return:
@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
) )
def _handle_stop(self, event: QueueStopEvent) -> None: def _handle_stop(self, event: QueueStopEvent):
""" """
Handle stop. Handle stop.
:return: :return:

View File

@ -48,7 +48,7 @@ class MessageCycleManager:
AdvancedChatAppGenerateEntity, AdvancedChatAppGenerateEntity,
], ],
task_state: Union[EasyUITaskState, WorkflowTaskState], task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None: ):
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._task_state = task_state self._task_state = task_state
@ -132,7 +132,7 @@ class MessageCycleManager:
return None return None
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent):
""" """
Handle retriever resources. Handle retriever resources.
:param event: event :param event: event

View File

@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str:
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None):
"""Print text with highlighting and no end characters.""" """Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file) print(text_to_print, end=end, file=file)
@ -37,7 +37,7 @@ class DifyAgentCallbackHandler(BaseModel):
color: Optional[str] = "" color: Optional[str] = ""
current_loop: int = 1 current_loop: int = 1
def __init__(self, color: Optional[str] = None) -> None: def __init__(self, color: Optional[str] = None):
super().__init__() super().__init__()
"""Initialize callback handler.""" """Initialize callback handler."""
# use a specific color is not specified # use a specific color is not specified
@ -48,7 +48,7 @@ class DifyAgentCallbackHandler(BaseModel):
self, self,
tool_name: str, tool_name: str,
tool_inputs: Mapping[str, Any], tool_inputs: Mapping[str, Any],
) -> None: ):
"""Do nothing.""" """Do nothing."""
if dify_config.DEBUG: if dify_config.DEBUG:
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@ -61,7 +61,7 @@ class DifyAgentCallbackHandler(BaseModel):
message_id: Optional[str] = None, message_id: Optional[str] = None,
timer: Optional[Any] = None, timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None, trace_manager: Optional[TraceQueueManager] = None,
) -> None: ):
"""If not the final action, print out observation.""" """If not the final action, print out observation."""
if dify_config.DEBUG: if dify_config.DEBUG:
print_text("\n[on_tool_end]\n", color=self.color) print_text("\n[on_tool_end]\n", color=self.color)
@ -82,12 +82,12 @@ class DifyAgentCallbackHandler(BaseModel):
) )
) )
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):
"""Do nothing.""" """Do nothing."""
if dify_config.DEBUG: if dify_config.DEBUG:
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
def on_agent_start(self, thought: str) -> None: def on_agent_start(self, thought: str):
"""Run on agent start.""" """Run on agent start."""
if dify_config.DEBUG: if dify_config.DEBUG:
if thought: if thought:
@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel):
else: else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any):
"""Run on agent end.""" """Run on agent end."""
if dify_config.DEBUG: if dify_config.DEBUG:
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)

View File

@ -21,14 +21,14 @@ class DatasetIndexToolCallbackHandler:
def __init__( def __init__(
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
) -> None: ):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._app_id = app_id self._app_id = app_id
self._message_id = message_id self._message_id = message_id
self._user_id = user_id self._user_id = user_id
self._invoke_from = invoke_from self._invoke_from = invoke_from
def on_query(self, query: str, dataset_id: str) -> None: def on_query(self, query: str, dataset_id: str):
""" """
Handle query. Handle query.
""" """
@ -46,7 +46,7 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()
def on_tool_end(self, documents: list[Document]) -> None: def on_tool_end(self, documents: list[Document]):
"""Handle tool end.""" """Handle tool end."""
for document in documents: for document in documents:
if document.metadata is not None: if document.metadata is not None:

View File

@ -30,8 +30,6 @@ class FakeDatasourceRuntime(DatasourceRuntime):
""" """
def __init__(self): def __init__(self):
super().__init__( super().__init__(
tenant_id="fake_tenant_id", tenant_id="fake_tenant_id",
datasource_id="fake_datasource_id", datasource_id="fake_datasource_id",

View File

@ -33,7 +33,7 @@ class SimpleModelProviderEntity(BaseModel):
icon_large: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType] supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity) -> None: def __init__(self, provider_entity: ProviderEntity):
""" """
Init simple provider. Init simple provider.
@ -57,7 +57,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
load_balancing_enabled: bool = False load_balancing_enabled: bool = False
has_invalid_load_balancing_configs: bool = False has_invalid_load_balancing_configs: bool = False
def raise_for_status(self) -> None: def raise_for_status(self):
""" """
Check model status and raise ValueError if not active. Check model status and raise ValueError if not active.

View File

@ -280,9 +280,7 @@ class ProviderConfiguration(BaseModel):
else [], else [],
) )
def validate_provider_credentials( def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
self, credentials: dict, credential_id: str = "", session: Session | None = None
) -> dict:
""" """
Validate custom credentials. Validate custom credentials.
:param credentials: provider credentials :param credentials: provider credentials
@ -291,7 +289,7 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
def _validate(s: Session) -> dict: def _validate(s: Session):
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas self.provider.provider_credential_schema.credential_form_schemas
@ -402,7 +400,7 @@ class ProviderConfiguration(BaseModel):
logger.warning("Error generating next credential name: %s", str(e)) logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1" return "API KEY 1"
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: def create_provider_credential(self, credentials: dict, credential_name: str | None):
""" """
Add custom provider credentials. Add custom provider credentials.
:param credentials: provider credentials :param credentials: provider credentials
@ -458,7 +456,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict, credentials: dict,
credential_id: str, credential_id: str,
credential_name: str | None, credential_name: str | None,
) -> None: ):
""" """
update a saved provider credential (by credential_id). update a saved provider credential (by credential_id).
@ -519,7 +517,7 @@ class ProviderConfiguration(BaseModel):
credential_record: ProviderCredential | ProviderModelCredential, credential_record: ProviderCredential | ProviderModelCredential,
credential_source: str, credential_source: str,
session: Session, session: Session,
) -> None: ):
""" """
Update load balancing configurations that reference the given credential_id. Update load balancing configurations that reference the given credential_id.
@ -559,7 +557,7 @@ class ProviderConfiguration(BaseModel):
session.commit() session.commit()
def delete_provider_credential(self, credential_id: str) -> None: def delete_provider_credential(self, credential_id: str):
""" """
Delete a saved provider credential (by credential_id). Delete a saved provider credential (by credential_id).
@ -636,7 +634,7 @@ class ProviderConfiguration(BaseModel):
session.rollback() session.rollback()
raise raise
def switch_active_provider_credential(self, credential_id: str) -> None: def switch_active_provider_credential(self, credential_id: str):
""" """
Switch active provider credential (copy the selected one into current active snapshot). Switch active provider credential (copy the selected one into current active snapshot).
@ -815,7 +813,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict, credentials: dict,
credential_id: str = "", credential_id: str = "",
session: Session | None = None, session: Session | None = None,
) -> dict: ):
""" """
Validate custom model credentials. Validate custom model credentials.
@ -826,7 +824,7 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
def _validate(s: Session) -> dict: def _validate(s: Session):
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas self.provider.model_credential_schema.credential_form_schemas
@ -1010,7 +1008,7 @@ class ProviderConfiguration(BaseModel):
session.rollback() session.rollback()
raise raise
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
""" """
Delete a saved provider credential (by credential_id). Delete a saved provider credential (by credential_id).
@ -1080,7 +1078,7 @@ class ProviderConfiguration(BaseModel):
session.rollback() session.rollback()
raise raise
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
""" """
if model list exist this custom model, switch the custom model credential. if model list exist this custom model, switch the custom model credential.
if model list not exist this custom model, use the credential to add a new custom model record. if model list not exist this custom model, use the credential to add a new custom model record.
@ -1123,7 +1121,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record) session.add(provider_model_record)
session.commit() session.commit()
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
""" """
switch the custom model credential. switch the custom model credential.
@ -1153,7 +1151,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record) session.add(provider_model_record)
session.commit() session.commit()
def delete_custom_model(self, model_type: ModelType, model: str) -> None: def delete_custom_model(self, model_type: ModelType, model: str):
""" """
Delete custom model. Delete custom model.
:param model_type: model type :param model_type: model type
@ -1350,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
) )
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
""" """
Switch preferred provider type. Switch preferred provider type.
:param provider_type: :param provider_type:
@ -1362,7 +1360,7 @@ class ProviderConfiguration(BaseModel):
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
return return
def _switch(s: Session) -> None: def _switch(s: Session):
# get preferred provider # get preferred provider
model_provider_id = ModelProviderID(self.provider.provider) model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider] provider_names = [self.provider.provider]
@ -1406,7 +1404,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
""" """
Obfuscated credentials. Obfuscated credentials.

View File

@ -6,7 +6,7 @@ class LLMError(ValueError):
description: Optional[str] = None description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None: def __init__(self, description: Optional[str] = None):
self.description = description self.description = description

View File

@ -10,11 +10,11 @@ class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60) timeout: tuple[int, int] = (5, 60)
"""timeout for request connect and read""" """timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None: def __init__(self, api_endpoint: str, api_key: str):
self.api_endpoint = api_endpoint self.api_endpoint = api_endpoint
self.api_key = api_key self.api_key = api_key
def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: def request(self, point: APIBasedExtensionPoint, params: dict):
""" """
Request the api. Request the api.

View File

@ -34,7 +34,7 @@ class Extensible:
tenant_id: str tenant_id: str
config: Optional[dict] = None config: Optional[dict] = None
def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: def __init__(self, tenant_id: str, config: Optional[dict] = None):
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.config = config self.config = config

View File

@ -18,7 +18,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool""" """the unique name of external data tool"""
@classmethod @classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None: def validate_config(cls, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.

View File

@ -16,14 +16,14 @@ class ExternalDataTool(Extensible, ABC):
variable: str variable: str
"""the tool variable name of app tool""" """the tool variable name of app tool"""
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None):
super().__init__(tenant_id, config) super().__init__(tenant_id, config)
self.app_id = app_id self.app_id = app_id
self.variable = variable self.variable = variable
@classmethod @classmethod
@abstractmethod @abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None: def validate_config(cls, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.

View File

@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory: class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class( self.__extension_instance = extension_class(
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
) )
@classmethod @classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: def validate_config(cls, name: str, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.

View File

@ -7,6 +7,6 @@ if TYPE_CHECKING:
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]):
global _tool_file_manager_factory global _tool_file_manager_factory
_tool_file_manager_factory = factory _tool_file_manager_factory = factory

View File

@ -22,7 +22,7 @@ class CodeNodeProvider(BaseModel):
pass pass
@classmethod @classmethod
def get_default_config(cls) -> dict: def get_default_config(cls):
return { return {
"type": "code", "type": "code",
"config": { "config": {

View File

@ -5,7 +5,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer): class Jinja2TemplateTransformer(TemplateTransformer):
@classmethod @classmethod
def transform_response(cls, response: str) -> dict: def transform_response(cls, response: str):
""" """
Transform response to dict Transform response to dict
:param response: response :param response: response

View File

@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
def get_default_code(cls) -> str: def get_default_code(cls) -> str:
return dedent( return dedent(
""" """
def main(arg1: str, arg2: str) -> dict: def main(arg1: str, arg2: str):
return { return {
"result": arg1 + arg2, "result": arg1 + arg2,
} }

View File

@ -34,7 +34,7 @@ class ProviderCredentialsCache:
else: else:
return None return None
def set(self, credentials: dict) -> None: def set(self, credentials: dict):
""" """
Cache model provider credentials. Cache model provider credentials.
@ -43,7 +43,7 @@ class ProviderCredentialsCache:
""" """
redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None: def delete(self):
""" """
Delete cached model provider credentials. Delete cached model provider credentials.

View File

@ -28,11 +28,11 @@ class ProviderCredentialsCache(ABC):
return None return None
return None return None
def set(self, config: dict[str, Any]) -> None: def set(self, config: dict[str, Any]):
"""Cache provider credentials""" """Cache provider credentials"""
redis_client.setex(self.cache_key, 86400, json.dumps(config)) redis_client.setex(self.cache_key, 86400, json.dumps(config))
def delete(self) -> None: def delete(self):
"""Delete cached provider credentials""" """Delete cached provider credentials"""
redis_client.delete(self.cache_key) redis_client.delete(self.cache_key)
@ -75,10 +75,10 @@ class NoOpProviderCredentialCache:
"""Get cached provider credentials""" """Get cached provider credentials"""
return None return None
def set(self, config: dict[str, Any]) -> None: def set(self, config: dict[str, Any]):
"""Cache provider credentials""" """Cache provider credentials"""
pass pass
def delete(self) -> None: def delete(self):
"""Delete cached provider credentials""" """Delete cached provider credentials"""
pass pass

View File

@ -37,11 +37,11 @@ class ToolParameterCache:
else: else:
return None return None
def set(self, parameters: dict) -> None: def set(self, parameters: dict):
"""Cache model provider credentials.""" """Cache model provider credentials."""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None: def delete(self):
""" """
Delete cached model provider credentials. Delete cached model provider credentials.

View File

@ -49,7 +49,7 @@ def get_external_trace_id(request: Any) -> Optional[str]:
return None return None
def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: def extract_external_trace_id_from_args(args: Mapping[str, Any]):
""" """
Extract 'external_trace_id' from args. Extract 'external_trace_id' from args.

View File

@ -44,11 +44,11 @@ class HostingConfiguration:
provider_map: dict[str, HostingProvider] provider_map: dict[str, HostingProvider]
moderation_config: Optional[HostedModerationConfig] = None moderation_config: Optional[HostedModerationConfig] = None
def __init__(self) -> None: def __init__(self):
self.provider_map = {} self.provider_map = {}
self.moderation_config = None self.moderation_config = None
def init_app(self, app: Flask) -> None: def init_app(self, app: Flask):
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
return return

View File

@ -270,7 +270,9 @@ class IndexingRunner:
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
) )
preview_texts = [] # type: ignore # keep separate, avoid union-list ambiguity
preview_texts: list[PreviewDetail] = []
qa_preview_texts: list[QAPreviewDetail] = []
total_segments = 0 total_segments = 0
index_type = doc_form index_type = doc_form
@ -293,14 +295,14 @@ class IndexingRunner:
for document in documents: for document in documents:
if len(preview_texts) < 10: if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model": if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail( qa_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer") or "" question=document.page_content, answer=document.metadata.get("answer") or ""
) )
preview_texts.append(preview_detail) qa_preview_texts.append(qa_detail)
else: else:
preview_detail = PreviewDetail(content=document.page_content) # type: ignore preview_detail = PreviewDetail(content=document.page_content)
if document.children: if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore preview_detail.child_chunks = [child.page_content for child in document.children]
preview_texts.append(preview_detail) preview_texts.append(preview_detail)
# delete image files and related db records # delete image files and related db records
@ -321,8 +323,8 @@ class IndexingRunner:
db.session.delete(image_file) db.session.delete(image_file)
if doc_form and doc_form == "qa_model": if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract( def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@ -425,6 +427,7 @@ class IndexingRunner:
""" """
Get the NodeParser object according to the processing rule. Get the NodeParser object according to the processing rule.
""" """
character_splitter: TextSplitter
if processing_rule_mode in ["custom", "hierarchical"]: if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule # The user-defined segmentation rule
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
@ -451,7 +454,7 @@ class IndexingRunner:
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
) )
return character_splitter # type: ignore return character_splitter
def _split_to_documents_for_estimate( def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@ -510,7 +513,7 @@ class IndexingRunner:
dataset: Dataset, dataset: Dataset,
dataset_document: DatasetDocument, dataset_document: DatasetDocument,
documents: list[Document], documents: list[Document],
) -> None: ):
""" """
insert index and update document/segment status to completed insert index and update document/segment status to completed
""" """
@ -649,7 +652,7 @@ class IndexingRunner:
@staticmethod @staticmethod
def _update_document_index_status( def _update_document_index_status(
document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
) -> None: ):
""" """
Update the document indexing status. Update the document indexing status.
""" """
@ -668,7 +671,7 @@ class IndexingRunner:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: def _update_segments_by_document(dataset_document_id: str, update_params: dict):
""" """
Update the document segment by document id. Update the document segment by document id.
""" """

View File

@ -129,7 +129,7 @@ class LLMGenerator:
return questions return questions
@classmethod @classmethod
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool):
output_parser = RuleConfigGeneratorOutputParser() output_parser = RuleConfigGeneratorOutputParser()
error = "" error = ""
@ -264,9 +264,7 @@ class LLMGenerator:
return rule_config return rule_config
@classmethod @classmethod
def generate_code( def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"
) -> dict:
if code_language == "python": if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else: else:
@ -375,7 +373,7 @@ class LLMGenerator:
@staticmethod @staticmethod
def instruction_modify_legacy( def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
) -> dict: ):
last_run: Message | None = ( last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
) )
@ -415,7 +413,7 @@ class LLMGenerator:
instruction: str, instruction: str,
model_config: dict, model_config: dict,
ideal_output: str | None, ideal_output: str | None,
) -> dict: ):
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
session = db.session() session = db.session()
@ -455,7 +453,7 @@ class LLMGenerator:
return [] return []
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
def dict_of_event(event: AgentLogEvent) -> dict: def dict_of_event(event: AgentLogEvent):
return { return {
"status": event.status, "status": event.status,
"error": event.error, "error": event.error,
@ -493,7 +491,7 @@ class LLMGenerator:
instruction: str, instruction: str,
node_type: str, node_type: str,
ideal_output: str | None, ideal_output: str | None,
) -> dict: ):
LAST_RUN = "{{#last_run#}}" LAST_RUN = "{{#last_run#}}"
CURRENT = "{{#current#}}" CURRENT = "{{#current#}}"
ERROR_MESSAGE = "{{#error_message#}}" ERROR_MESSAGE = "{{#error_message#}}"

View File

@ -1,5 +1,3 @@
from typing import Any
from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.prompts import ( from core.llm_generator.prompts import (
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
@ -17,7 +15,7 @@ class RuleConfigGeneratorOutputParser:
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
) )
def parse(self, text: str) -> Any: def parse(self, text: str):
try: try:
expected_keys = ["prompt", "variables", "opening_statement"] expected_keys = ["prompt", "variables", "opening_statement"]
parsed = parse_and_check_json_markdown(text, expected_keys) parsed = parse_and_check_json_markdown(text, expected_keys)

View File

@ -210,7 +210,7 @@ def _handle_native_json_schema(
structured_output_schema: Mapping, structured_output_schema: Mapping,
model_parameters: dict, model_parameters: dict,
rules: list[ParameterRule], rules: list[ParameterRule],
) -> dict: ):
""" """
Handle structured output for models with native JSON schema support. Handle structured output for models with native JSON schema support.
@ -232,7 +232,7 @@ def _handle_native_json_schema(
return model_parameters return model_parameters
def _set_response_format(model_parameters: dict, rules: list) -> None: def _set_response_format(model_parameters: dict, rules: list):
""" """
Set the appropriate response format parameter based on model rules. Set the appropriate response format parameter based on model rules.
@ -306,7 +306,7 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
return structured_output return structured_output
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict: def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
""" """
Prepare JSON schema based on model requirements. Prepare JSON schema based on model requirements.
@ -334,7 +334,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"} return {"schema": processed_schema, "name": "llm_response"}
def remove_additional_properties(schema: dict) -> None: def remove_additional_properties(schema: dict):
""" """
Remove additionalProperties fields from JSON schema. Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property. Used for models like Gemini that don't support this property.
@ -357,7 +357,7 @@ def remove_additional_properties(schema: dict) -> None:
remove_additional_properties(item) remove_additional_properties(item)
def convert_boolean_to_string(schema: dict) -> None: def convert_boolean_to_string(schema: dict):
""" """
Convert boolean type specifications to string in JSON schema. Convert boolean type specifications to string in JSON schema.

View File

@ -1,6 +1,5 @@
import json import json
import re import re
from typing import Any
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
@ -9,7 +8,7 @@ class SuggestedQuestionsAfterAnswerOutputParser:
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
def parse(self, text: str) -> Any: def parse(self, text: str):
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
if action_match is not None: if action_match is not None:
json_obj = json.loads(action_match.group(0).strip()) json_obj = json.loads(action_match.group(0).strip())

View File

@ -44,7 +44,7 @@ class OAuthClientProvider:
return None return None
return OAuthClientInformation.model_validate(client_information) return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull) -> None: def save_client_information(self, client_information: OAuthClientInformationFull):
"""Saves client information after dynamic registration.""" """Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials( MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider, self.mcp_provider,
@ -63,13 +63,13 @@ class OAuthClientProvider:
refresh_token=credentials.get("refresh_token", ""), refresh_token=credentials.get("refresh_token", ""),
) )
def save_tokens(self, tokens: OAuthTokens) -> None: def save_tokens(self, tokens: OAuthTokens):
"""Stores new OAuth tokens for the current session.""" """Stores new OAuth tokens for the current session."""
# update mcp provider credentials # update mcp provider credentials
token_dict = tokens.model_dump() token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str) -> None: def save_code_verifier(self, code_verifier: str):
"""Saves a PKCE code verifier for the current session.""" """Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})

View File

@ -47,7 +47,7 @@ class SSETransport:
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float = 5.0, timeout: float = 5.0,
sse_read_timeout: float = 5 * 60, sse_read_timeout: float = 5 * 60,
) -> None: ):
"""Initialize the SSE transport. """Initialize the SSE transport.
Args: Args:
@ -76,7 +76,7 @@ class SSETransport:
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None: def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue):
"""Handle an 'endpoint' SSE event. """Handle an 'endpoint' SSE event.
Args: Args:
@ -94,7 +94,7 @@ class SSETransport:
status_queue.put(_StatusReady(endpoint_url)) status_queue.put(_StatusReady(endpoint_url))
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: def _handle_message_event(self, sse_data: str, read_queue: ReadQueue):
"""Handle a 'message' SSE event. """Handle a 'message' SSE event.
Args: Args:
@ -110,7 +110,7 @@ class SSETransport:
logger.exception("Error parsing server message") logger.exception("Error parsing server message")
read_queue.put(exc) read_queue.put(exc)
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue):
"""Handle a single SSE event. """Handle a single SSE event.
Args: Args:
@ -126,7 +126,7 @@ class SSETransport:
case _: case _:
logger.warning("Unknown SSE event: %s", sse.event) logger.warning("Unknown SSE event: %s", sse.event)
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue):
"""Read and process SSE events. """Read and process SSE events.
Args: Args:
@ -144,7 +144,7 @@ class SSETransport:
finally: finally:
read_queue.put(None) read_queue.put(None)
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None: def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage):
"""Send a single message to the server. """Send a single message to the server.
Args: Args:
@ -163,7 +163,7 @@ class SSETransport:
response.raise_for_status() response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code) logger.debug("Client message sent successfully: %s", response.status_code)
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue):
"""Handle writing messages to the server. """Handle writing messages to the server.
Args: Args:
@ -303,7 +303,7 @@ def sse_client(
write_queue.put(None) write_queue.put(None)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
""" """
Send a message to the server using the provided HTTP client. Send a message to the server using the provided HTTP client.

View File

@ -82,7 +82,7 @@ class StreamableHTTPTransport:
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30, timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5, sse_read_timeout: float | timedelta = 60 * 5,
) -> None: ):
"""Initialize the StreamableHTTP transport. """Initialize the StreamableHTTP transport.
Args: Args:
@ -122,7 +122,7 @@ class StreamableHTTPTransport:
def _maybe_extract_session_id_from_response( def _maybe_extract_session_id_from_response(
self, self,
response: httpx.Response, response: httpx.Response,
) -> None: ):
"""Extract and store session ID from response headers.""" """Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID) new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id: if new_session_id:
@ -173,7 +173,7 @@ class StreamableHTTPTransport:
self, self,
client: httpx.Client, client: httpx.Client,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
) -> None: ):
"""Handle GET stream for server-initiated messages.""" """Handle GET stream for server-initiated messages."""
try: try:
if not self.session_id: if not self.session_id:
@ -197,7 +197,7 @@ class StreamableHTTPTransport:
except Exception as exc: except Exception as exc:
logger.debug("GET stream error (non-fatal): %s", exc) logger.debug("GET stream error (non-fatal): %s", exc)
def _handle_resumption_request(self, ctx: RequestContext) -> None: def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE.""" """Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers) headers = self._update_headers_with_session(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token: if ctx.metadata and ctx.metadata.resumption_token:
@ -230,7 +230,7 @@ class StreamableHTTPTransport:
if is_complete: if is_complete:
break break
def _handle_post_request(self, ctx: RequestContext) -> None: def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing.""" """Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers) headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message message = ctx.session_message.message
@ -278,7 +278,7 @@ class StreamableHTTPTransport:
self, self,
response: httpx.Response, response: httpx.Response,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
) -> None: ):
"""Handle JSON response from the server.""" """Handle JSON response from the server."""
try: try:
content = response.read() content = response.read()
@ -288,7 +288,7 @@ class StreamableHTTPTransport:
except Exception as exc: except Exception as exc:
server_to_client_queue.put(exc) server_to_client_queue.put(exc)
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server.""" """Handle SSE response from the server."""
try: try:
event_source = EventSource(response) event_source = EventSource(response)
@ -307,7 +307,7 @@ class StreamableHTTPTransport:
self, self,
content_type: str, content_type: str,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
) -> None: ):
"""Handle unexpected content type in response.""" """Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}" error_msg = f"Unexpected content type: {content_type}"
logger.error(error_msg) logger.error(error_msg)
@ -317,7 +317,7 @@ class StreamableHTTPTransport:
self, self,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
request_id: RequestId, request_id: RequestId,
) -> None: ):
"""Send a session terminated error response.""" """Send a session terminated error response."""
jsonrpc_error = JSONRPCError( jsonrpc_error = JSONRPCError(
jsonrpc="2.0", jsonrpc="2.0",
@ -333,7 +333,7 @@ class StreamableHTTPTransport:
client_to_server_queue: ClientToServerQueue, client_to_server_queue: ClientToServerQueue,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
start_get_stream: Callable[[], None], start_get_stream: Callable[[], None],
) -> None: ):
"""Handle writing requests to the server. """Handle writing requests to the server.
This method processes messages from the client_to_server_queue and sends them to the server. This method processes messages from the client_to_server_queue and sends them to the server.
@ -379,7 +379,7 @@ class StreamableHTTPTransport:
except Exception as exc: except Exception as exc:
server_to_client_queue.put(exc) server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None: def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request.""" """Terminate the session by sending a DELETE request."""
if not self.session_id: if not self.session_id:
return return
@ -441,7 +441,7 @@ def streamablehttp_client(
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client: ) as client:
# Define callbacks that need access to thread pool # Define callbacks that need access to thread pool
def start_get_stream() -> None: def start_get_stream():
"""Start a worker thread to handle server-initiated messages.""" """Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue) executor.submit(transport.handle_get_stream, client, server_to_client_queue)

View File

@ -76,7 +76,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
ReceiveNotificationT ReceiveNotificationT
]""", ]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None: ):
self.request_id = request_id self.request_id = request_id
self.request_meta = request_meta self.request_meta = request_meta
self.request = request self.request = request
@ -95,7 +95,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
exc_type: type[BaseException] | None, exc_type: type[BaseException] | None,
exc_val: BaseException | None, exc_val: BaseException | None,
exc_tb: TracebackType | None, exc_tb: TracebackType | None,
) -> None: ):
"""Exit the context manager, performing cleanup and notifying completion.""" """Exit the context manager, performing cleanup and notifying completion."""
try: try:
if self._completed: if self._completed:
@ -103,7 +103,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
finally: finally:
self._entered = False self._entered = False
def respond(self, response: SendResultT | ErrorData) -> None: def respond(self, response: SendResultT | ErrorData):
"""Send a response for this request. """Send a response for this request.
Must be called within a context manager block. Must be called within a context manager block.
@ -119,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._session._send_response(request_id=self.request_id, response=response) self._session._send_response(request_id=self.request_id, response=response)
def cancel(self) -> None: def cancel(self):
"""Cancel this request and mark it as completed.""" """Cancel this request and mark it as completed."""
if not self._entered: if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager") raise RuntimeError("RequestResponder must be used as a context manager")
@ -163,7 +163,7 @@ class BaseSession(
receive_notification_type: type[ReceiveNotificationT], receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out # If none, reading will never time out
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
) -> None: ):
self._read_stream = read_stream self._read_stream = read_stream
self._write_stream = write_stream self._write_stream = write_stream
self._response_streams = {} self._response_streams = {}
@ -183,7 +183,7 @@ class BaseSession(
self._receiver_future = self._executor.submit(self._receive_loop) self._receiver_future = self._executor.submit(self._receive_loop)
return self return self
def check_receiver_status(self) -> None: def check_receiver_status(self):
"""`check_receiver_status` ensures that any exceptions raised during the """`check_receiver_status` ensures that any exceptions raised during the
execution of `_receive_loop` are retrieved and propagated.""" execution of `_receive_loop` are retrieved and propagated."""
if self._receiver_future and self._receiver_future.done(): if self._receiver_future and self._receiver_future.done():
@ -191,7 +191,7 @@ class BaseSession(
def __exit__( def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ):
self._read_stream.put(None) self._read_stream.put(None)
self._write_stream.put(None) self._write_stream.put(None)
@ -277,7 +277,7 @@ class BaseSession(
self, self,
notification: SendNotificationT, notification: SendNotificationT,
related_request_id: RequestId | None = None, related_request_id: RequestId | None = None,
) -> None: ):
""" """
Emits a notification, which is a one-way message that does not expect Emits a notification, which is a one-way message that does not expect
a response. a response.
@ -296,7 +296,7 @@ class BaseSession(
) )
self._write_stream.put(session_message) self._write_stream.put(session_message)
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
if isinstance(response, ErrorData): if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@ -310,7 +310,7 @@ class BaseSession(
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
self._write_stream.put(session_message) self._write_stream.put(session_message)
def _receive_loop(self) -> None: def _receive_loop(self):
""" """
Main message processing loop. Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread. In a real synchronous implementation, this would likely run in a separate thread.
@ -382,7 +382,7 @@ class BaseSession(
logger.exception("Error in message processing loop") logger.exception("Error in message processing loop")
raise raise
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]):
""" """
Can be overridden by subclasses to handle a request without needing to Can be overridden by subclasses to handle a request without needing to
listen on the message stream. listen on the message stream.
@ -391,15 +391,13 @@ class BaseSession(
forwarded on to the message stream. forwarded on to the message stream.
""" """
def _received_notification(self, notification: ReceiveNotificationT) -> None: def _received_notification(self, notification: ReceiveNotificationT):
""" """
Can be overridden by subclasses to handle a notification without needing Can be overridden by subclasses to handle a notification without needing
to listen on the message stream. to listen on the message stream.
""" """
def send_progress_notification( def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
""" """
Sends a progress notification for a request that is currently being Sends a progress notification for a request that is currently being
processed. processed.
@ -408,5 +406,5 @@ class BaseSession(
def _handle_incoming( def _handle_incoming(
self, self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None: ):
"""A generic handler for incoming messages. Overwritten by subclasses.""" """A generic handler for incoming messages. Overwritten by subclasses."""

View File

@ -28,19 +28,19 @@ class LoggingFnT(Protocol):
def __call__( def __call__(
self, self,
params: types.LoggingMessageNotificationParams, params: types.LoggingMessageNotificationParams,
) -> None: ... ): ...
class MessageHandlerFnT(Protocol): class MessageHandlerFnT(Protocol):
def __call__( def __call__(
self, self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ... ): ...
def _default_message_handler( def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ):
if isinstance(message, Exception): if isinstance(message, Exception):
raise ValueError(str(message)) raise ValueError(str(message))
elif isinstance(message, (types.ServerNotification | RequestResponder)): elif isinstance(message, (types.ServerNotification | RequestResponder)):
@ -68,7 +68,7 @@ def _default_list_roots_callback(
def _default_logging_callback( def _default_logging_callback(
params: types.LoggingMessageNotificationParams, params: types.LoggingMessageNotificationParams,
) -> None: ):
pass pass
@ -94,7 +94,7 @@ class ClientSession(
logging_callback: LoggingFnT | None = None, logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None, message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None, client_info: types.Implementation | None = None,
) -> None: ):
super().__init__( super().__init__(
read_stream, read_stream,
write_stream, write_stream,
@ -155,9 +155,7 @@ class ClientSession(
types.EmptyResult, types.EmptyResult,
) )
def send_progress_notification( def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification.""" """Send a progress notification."""
self.send_notification( self.send_notification(
types.ClientNotification( types.ClientNotification(
@ -314,7 +312,7 @@ class ClientSession(
types.ListToolsResult, types.ListToolsResult,
) )
def send_roots_list_changed(self) -> None: def send_roots_list_changed(self):
"""Send a roots/list_changed notification.""" """Send a roots/list_changed notification."""
self.send_notification( self.send_notification(
types.ClientNotification( types.ClientNotification(
@ -324,7 +322,7 @@ class ClientSession(
) )
) )
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
ctx = RequestContext[ClientSession, Any]( ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id, request_id=responder.request_id,
meta=responder.request_meta, meta=responder.request_meta,
@ -352,11 +350,11 @@ class ClientSession(
def _handle_incoming( def _handle_incoming(
self, self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ):
"""Handle incoming messages by forwarding to the message handler.""" """Handle incoming messages by forwarding to the message handler."""
self._message_handler(req) self._message_handler(req)
def _received_notification(self, notification: types.ServerNotification) -> None: def _received_notification(self, notification: types.ServerNotification):
"""Handle notifications from the server.""" """Handle notifications from the server."""
# Process specific notification types # Process specific notification types
match notification.root: match notification.root:

View File

@ -27,7 +27,7 @@ class TokenBufferMemory:
self, self,
conversation: Conversation, conversation: Conversation,
model_instance: ModelInstance, model_instance: ModelInstance,
) -> None: ):
self.conversation = conversation self.conversation = conversation
self.model_instance = model_instance self.model_instance = model_instance
@ -124,6 +124,7 @@ class TokenBufferMemory:
messages = list(reversed(thread_messages)) messages = list(reversed(thread_messages))
curr_message_tokens = 0
prompt_messages: list[PromptMessage] = [] prompt_messages: list[PromptMessage] = []
for message in messages: for message in messages:
# Process user message with files # Process user message with files

View File

@ -32,7 +32,7 @@ class ModelInstance:
Model instance class Model instance class
""" """
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
self.provider_model_bundle = provider_model_bundle self.provider_model_bundle = provider_model_bundle
self.model = model self.model = model
self.provider = provider_model_bundle.configuration.provider.provider self.provider = provider_model_bundle.configuration.provider.provider
@ -46,7 +46,7 @@ class ModelInstance:
) )
@staticmethod @staticmethod
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str):
""" """
Fetch credentials from provider model bundle Fetch credentials from provider model bundle
:param provider_model_bundle: provider model bundle :param provider_model_bundle: provider model bundle
@ -342,7 +342,7 @@ class ModelInstance:
), ),
) )
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
""" """
Round-robin invoke Round-robin invoke
:param function: function to invoke :param function: function to invoke
@ -379,7 +379,7 @@ class ModelInstance:
except Exception as e: except Exception as e:
raise e raise e
def get_tts_voices(self, language: Optional[str] = None) -> list: def get_tts_voices(self, language: Optional[str] = None):
""" """
Invoke large language tts model voices Invoke large language tts model voices
@ -394,7 +394,7 @@ class ModelInstance:
class ModelManager: class ModelManager:
def __init__(self) -> None: def __init__(self):
self._provider_manager = ProviderManager() self._provider_manager = ProviderManager()
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
@ -453,7 +453,7 @@ class LBModelManager:
model: str, model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration], load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: Optional[dict] = None, managed_credentials: Optional[dict] = None,
) -> None: ):
""" """
Load balancing model manager Load balancing model manager
:param tenant_id: tenant_id :param tenant_id: tenant_id
@ -534,7 +534,7 @@ model: %s""",
return config return config
def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60):
""" """
Cooldown model load balancing config Cooldown model load balancing config
:param config: model load balancing config :param config: model load balancing config

Some files were not shown because too many files have changed in this diff Show More