mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
commit
23cd615489
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -47,6 +47,10 @@ jobs:
|
|||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: dev/basedpyright-check
|
run: dev/basedpyright-check
|
||||||
|
|
||||||
|
- name: Run Mypy Type Checks
|
||||||
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
|
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||||
|
|
||||||
- name: Dotenv check
|
- name: Dotenv check
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||||
|
|||||||
8
.gitignore
vendored
8
.gitignore
vendored
@ -198,6 +198,7 @@ sdks/python-client/dify_client.egg-info
|
|||||||
!.vscode/launch.json.template
|
!.vscode/launch.json.template
|
||||||
!.vscode/README.md
|
!.vscode/README.md
|
||||||
api/.vscode
|
api/.vscode
|
||||||
|
web/.vscode
|
||||||
# vscode Code History Extension
|
# vscode Code History Extension
|
||||||
.history
|
.history
|
||||||
|
|
||||||
@ -215,6 +216,13 @@ mise.toml
|
|||||||
# Next.js build output
|
# Next.js build output
|
||||||
.next/
|
.next/
|
||||||
|
|
||||||
|
# PWA generated files
|
||||||
|
web/public/sw.js
|
||||||
|
web/public/sw.js.map
|
||||||
|
web/public/workbox-*.js
|
||||||
|
web/public/workbox-*.js.map
|
||||||
|
web/public/fallback-*.js
|
||||||
|
|
||||||
# AI Assistant
|
# AI Assistant
|
||||||
.roo/
|
.roo/
|
||||||
api/.env.backup
|
api/.env.backup
|
||||||
|
|||||||
@ -25,6 +25,9 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||||||
# add an unique identifier to each request
|
# add an unique identifier to each request
|
||||||
RecyclableContextVar.increment_thread_recycles()
|
RecyclableContextVar.increment_thread_recycles()
|
||||||
|
|
||||||
|
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||||
|
_ = before_request
|
||||||
|
|
||||||
return dify_app
|
return dify_app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.plugin.entities.plugin import PluginInstallationSource
|
|
||||||
from core.plugin.impl.plugin import PluginInstaller
|
from core.plugin.impl.plugin import PluginInstaller
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
@ -1494,7 +1493,7 @@ def transform_datasource_credentials():
|
|||||||
for credential in credentials:
|
for credential in credentials:
|
||||||
auth_count += 1
|
auth_count += 1
|
||||||
# get credential api key
|
# get credential api key
|
||||||
credentials_json =json.loads(credential.credentials)
|
credentials_json = json.loads(credential.credentials)
|
||||||
api_key = credentials_json.get("config", {}).get("api_key")
|
api_key = credentials_json.get("config", {}).get("api_key")
|
||||||
base_url = credentials_json.get("config", {}).get("base_url")
|
base_url = credentials_json.get("config", {}).get("base_url")
|
||||||
new_credentials = {
|
new_credentials = {
|
||||||
|
|||||||
@ -300,8 +300,7 @@ class DatasetQueueMonitorConfig(BaseSettings):
|
|||||||
|
|
||||||
class MiddlewareConfig(
|
class MiddlewareConfig(
|
||||||
# place the configs in alphabet order
|
# place the configs in alphabet order
|
||||||
CeleryConfig,
|
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
||||||
DatabaseConfig,
|
|
||||||
KeywordStoreConfig,
|
KeywordStoreConfig,
|
||||||
RedisConfig,
|
RedisConfig,
|
||||||
# configs of storage and storage providers
|
# configs of storage and storage providers
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
class ClickzettaConfig(BaseModel):
|
class ClickzettaConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Clickzetta Lakehouse vector database configuration
|
Clickzetta Lakehouse vector database configuration
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
class MatrixoneConfig(BaseModel):
|
class MatrixoneConfig(BaseSettings):
|
||||||
"""Matrixone vector database configuration."""
|
"""Matrixone vector database configuration."""
|
||||||
|
|
||||||
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
|
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
|
from configs.packaging.pyproject import PyProjectTomlConfig
|
||||||
|
|
||||||
|
|
||||||
class PackagingInfo(PyProjectTomlConfig):
|
class PackagingInfo(PyProjectTomlConfig):
|
||||||
|
|||||||
@ -4,8 +4,9 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from .python_3x import http_request, makedirs_wrapper
|
from .python_3x import http_request, makedirs_wrapper
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@ -25,13 +26,13 @@ logger = logging.getLogger(__name__)
|
|||||||
class ApolloClient:
|
class ApolloClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config_url,
|
config_url: str,
|
||||||
app_id,
|
app_id: str,
|
||||||
cluster="default",
|
cluster: str = "default",
|
||||||
secret="",
|
secret: str = "",
|
||||||
start_hot_update=True,
|
start_hot_update: bool = True,
|
||||||
change_listener=None,
|
change_listener: Callable[[str, str, str, Any], None] | None = None,
|
||||||
_notification_map=None,
|
_notification_map: dict[str, int] | None = None,
|
||||||
):
|
):
|
||||||
# Core routing parameters
|
# Core routing parameters
|
||||||
self.config_url = config_url
|
self.config_url = config_url
|
||||||
@ -47,17 +48,17 @@ class ApolloClient:
|
|||||||
# Private control variables
|
# Private control variables
|
||||||
self._cycle_time = 5
|
self._cycle_time = 5
|
||||||
self._stopping = False
|
self._stopping = False
|
||||||
self._cache = {}
|
self._cache: dict[str, dict[str, Any]] = {}
|
||||||
self._no_key = {}
|
self._no_key: dict[str, str] = {}
|
||||||
self._hash = {}
|
self._hash: dict[str, str] = {}
|
||||||
self._pull_timeout = 75
|
self._pull_timeout = 75
|
||||||
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
|
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
|
||||||
self._long_poll_thread = None
|
self._long_poll_thread: threading.Thread | None = None
|
||||||
self._change_listener = change_listener # "add" "delete" "update"
|
self._change_listener = change_listener # "add" "delete" "update"
|
||||||
if _notification_map is None:
|
if _notification_map is None:
|
||||||
_notification_map = {"application": -1}
|
_notification_map = {"application": -1}
|
||||||
self._notification_map = _notification_map
|
self._notification_map = _notification_map
|
||||||
self.last_release_key = None
|
self.last_release_key: str | None = None
|
||||||
# Private startup method
|
# Private startup method
|
||||||
self._path_checker()
|
self._path_checker()
|
||||||
if start_hot_update:
|
if start_hot_update:
|
||||||
@ -68,7 +69,7 @@ class ApolloClient:
|
|||||||
heartbeat.daemon = True
|
heartbeat.daemon = True
|
||||||
heartbeat.start()
|
heartbeat.start()
|
||||||
|
|
||||||
def get_json_from_net(self, namespace="application"):
|
def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
|
||||||
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
|
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
|
||||||
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
|
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
|
||||||
)
|
)
|
||||||
@ -88,7 +89,7 @@ class ApolloClient:
|
|||||||
logger.exception("an error occurred in get_json_from_net")
|
logger.exception("an error occurred in get_json_from_net")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_value(self, key, default_val=None, namespace="application"):
|
def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
|
||||||
try:
|
try:
|
||||||
# read memory configuration
|
# read memory configuration
|
||||||
namespace_cache = self._cache.get(namespace)
|
namespace_cache = self._cache.get(namespace)
|
||||||
@ -104,7 +105,8 @@ class ApolloClient:
|
|||||||
namespace_data = self.get_json_from_net(namespace)
|
namespace_data = self.get_json_from_net(namespace)
|
||||||
val = get_value_from_dict(namespace_data, key)
|
val = get_value_from_dict(namespace_data, key)
|
||||||
if val is not None:
|
if val is not None:
|
||||||
self._update_cache_and_file(namespace_data, namespace)
|
if namespace_data is not None:
|
||||||
|
self._update_cache_and_file(namespace_data, namespace)
|
||||||
return val
|
return val
|
||||||
|
|
||||||
# read the file configuration
|
# read the file configuration
|
||||||
@ -126,23 +128,23 @@ class ApolloClient:
|
|||||||
# to ensure the real-time correctness of the function call.
|
# to ensure the real-time correctness of the function call.
|
||||||
# If the user does not have the same default val twice
|
# If the user does not have the same default val twice
|
||||||
# and the default val is used here, there may be a problem.
|
# and the default val is used here, there may be a problem.
|
||||||
def _set_local_cache_none(self, namespace, key):
|
def _set_local_cache_none(self, namespace: str, key: str) -> None:
|
||||||
no_key = no_key_cache_key(namespace, key)
|
no_key = no_key_cache_key(namespace, key)
|
||||||
self._no_key[no_key] = key
|
self._no_key[no_key] = key
|
||||||
|
|
||||||
def _start_hot_update(self):
|
def _start_hot_update(self) -> None:
|
||||||
self._long_poll_thread = threading.Thread(target=self._listener)
|
self._long_poll_thread = threading.Thread(target=self._listener)
|
||||||
# When the asynchronous thread is started, the daemon thread will automatically exit
|
# When the asynchronous thread is started, the daemon thread will automatically exit
|
||||||
# when the main thread is launched.
|
# when the main thread is launched.
|
||||||
self._long_poll_thread.daemon = True
|
self._long_poll_thread.daemon = True
|
||||||
self._long_poll_thread.start()
|
self._long_poll_thread.start()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self) -> None:
|
||||||
self._stopping = True
|
self._stopping = True
|
||||||
logger.info("Stopping listener...")
|
logger.info("Stopping listener...")
|
||||||
|
|
||||||
# Call the set callback function, and if it is abnormal, try it out
|
# Call the set callback function, and if it is abnormal, try it out
|
||||||
def _call_listener(self, namespace, old_kv, new_kv):
|
def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
|
||||||
if self._change_listener is None:
|
if self._change_listener is None:
|
||||||
return
|
return
|
||||||
if old_kv is None:
|
if old_kv is None:
|
||||||
@ -168,12 +170,12 @@ class ApolloClient:
|
|||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|
||||||
def _path_checker(self):
|
def _path_checker(self) -> None:
|
||||||
if not os.path.isdir(self._cache_file_path):
|
if not os.path.isdir(self._cache_file_path):
|
||||||
makedirs_wrapper(self._cache_file_path)
|
makedirs_wrapper(self._cache_file_path)
|
||||||
|
|
||||||
# update the local cache and file cache
|
# update the local cache and file cache
|
||||||
def _update_cache_and_file(self, namespace_data, namespace="application"):
|
def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
|
||||||
# update the local cache
|
# update the local cache
|
||||||
self._cache[namespace] = namespace_data
|
self._cache[namespace] = namespace_data
|
||||||
# update the file cache
|
# update the file cache
|
||||||
@ -187,7 +189,7 @@ class ApolloClient:
|
|||||||
self._hash[namespace] = new_hash
|
self._hash[namespace] = new_hash
|
||||||
|
|
||||||
# get the configuration from the local file
|
# get the configuration from the local file
|
||||||
def _get_local_cache(self, namespace="application"):
|
def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
|
||||||
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
|
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
|
||||||
if os.path.isfile(cache_file_path):
|
if os.path.isfile(cache_file_path):
|
||||||
with open(cache_file_path) as f:
|
with open(cache_file_path) as f:
|
||||||
@ -195,8 +197,8 @@ class ApolloClient:
|
|||||||
return result
|
return result
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _long_poll(self):
|
def _long_poll(self) -> None:
|
||||||
notifications = []
|
notifications: list[dict[str, Any]] = []
|
||||||
for key in self._cache:
|
for key in self._cache:
|
||||||
namespace_data = self._cache[key]
|
namespace_data = self._cache[key]
|
||||||
notification_id = -1
|
notification_id = -1
|
||||||
@ -236,7 +238,7 @@ class ApolloClient:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|
||||||
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
|
def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
|
||||||
namespace_data = self.get_json_from_net(namespace)
|
namespace_data = self.get_json_from_net(namespace)
|
||||||
if not namespace_data:
|
if not namespace_data:
|
||||||
return
|
return
|
||||||
@ -248,7 +250,7 @@ class ApolloClient:
|
|||||||
new_kv = namespace_data.get(CONFIGURATIONS)
|
new_kv = namespace_data.get(CONFIGURATIONS)
|
||||||
self._call_listener(namespace, old_kv, new_kv)
|
self._call_listener(namespace, old_kv, new_kv)
|
||||||
|
|
||||||
def _listener(self):
|
def _listener(self) -> None:
|
||||||
logger.info("start long_poll")
|
logger.info("start long_poll")
|
||||||
while not self._stopping:
|
while not self._stopping:
|
||||||
self._long_poll()
|
self._long_poll()
|
||||||
@ -266,13 +268,13 @@ class ApolloClient:
|
|||||||
headers["Timestamp"] = time_unix_now
|
headers["Timestamp"] = time_unix_now
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def _heart_beat(self):
|
def _heart_beat(self) -> None:
|
||||||
while not self._stopping:
|
while not self._stopping:
|
||||||
for namespace in self._notification_map:
|
for namespace in self._notification_map:
|
||||||
self._do_heart_beat(namespace)
|
self._do_heart_beat(namespace)
|
||||||
time.sleep(60 * 10) # 10 minutes
|
time.sleep(60 * 10) # 10 minutes
|
||||||
|
|
||||||
def _do_heart_beat(self, namespace):
|
def _do_heart_beat(self, namespace: str) -> None:
|
||||||
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
|
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
|
||||||
try:
|
try:
|
||||||
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
||||||
@ -292,7 +294,7 @@ class ApolloClient:
|
|||||||
logger.exception("an error occurred in _do_heart_beat")
|
logger.exception("an error occurred in _do_heart_beat")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_all_dicts(self, namespace):
|
def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
|
||||||
namespace_data = self._cache.get(namespace)
|
namespace_data = self._cache.get(namespace)
|
||||||
if namespace_data is None:
|
if namespace_data is None:
|
||||||
net_namespace_data = self.get_json_from_net(namespace)
|
net_namespace_data = self.get_json_from_net(namespace)
|
||||||
|
|||||||
@ -2,6 +2,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
from urllib.error import HTTPError
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
@ -19,9 +21,9 @@ urllib.request.install_opener(opener)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def http_request(url, timeout, headers={}):
|
def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
|
||||||
try:
|
try:
|
||||||
request = urllib.request.Request(url, headers=headers)
|
request = urllib.request.Request(url, headers=dict(headers))
|
||||||
res = urllib.request.urlopen(request, timeout=timeout)
|
res = urllib.request.urlopen(request, timeout=timeout)
|
||||||
body = res.read().decode("utf-8")
|
body = res.read().decode("utf-8")
|
||||||
return res.code, body
|
return res.code, body
|
||||||
@ -33,9 +35,9 @@ def http_request(url, timeout, headers={}):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def url_encode(params):
|
def url_encode(params: dict[str, Any]) -> str:
|
||||||
return parse.urlencode(params)
|
return parse.urlencode(params)
|
||||||
|
|
||||||
|
|
||||||
def makedirs_wrapper(path):
|
def makedirs_wrapper(path: str) -> None:
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import socket
|
import socket
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from .python_3x import url_encode
|
from .python_3x import url_encode
|
||||||
|
|
||||||
@ -10,7 +11,7 @@ NAMESPACE_NAME = "namespaceName"
|
|||||||
|
|
||||||
|
|
||||||
# add timestamps uris and keys
|
# add timestamps uris and keys
|
||||||
def signature(timestamp, uri, secret):
|
def signature(timestamp: str, uri: str, secret: str) -> str:
|
||||||
import base64
|
import base64
|
||||||
import hmac
|
import hmac
|
||||||
|
|
||||||
@ -19,16 +20,16 @@ def signature(timestamp, uri, secret):
|
|||||||
return base64.b64encode(hmac_code).decode()
|
return base64.b64encode(hmac_code).decode()
|
||||||
|
|
||||||
|
|
||||||
def url_encode_wrapper(params):
|
def url_encode_wrapper(params: dict[str, Any]) -> str:
|
||||||
return url_encode(params)
|
return url_encode(params)
|
||||||
|
|
||||||
|
|
||||||
def no_key_cache_key(namespace, key):
|
def no_key_cache_key(namespace: str, key: str) -> str:
|
||||||
return f"{namespace}{len(namespace)}{key}"
|
return f"{namespace}{len(namespace)}{key}"
|
||||||
|
|
||||||
|
|
||||||
# Returns whether the obtained value is obtained, and None if it does not
|
# Returns whether the obtained value is obtained, and None if it does not
|
||||||
def get_value_from_dict(namespace_cache, key):
|
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
|
||||||
if namespace_cache:
|
if namespace_cache:
|
||||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||||
if kv_data is None:
|
if kv_data is None:
|
||||||
@ -38,7 +39,7 @@ def get_value_from_dict(namespace_cache, key):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def init_ip():
|
def init_ip() -> str:
|
||||||
ip = ""
|
ip = ""
|
||||||
s = None
|
s = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -11,5 +11,5 @@ class RemoteSettingsSource:
|
|||||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool):
|
||||||
return value
|
return value
|
||||||
|
|||||||
@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from configs.remote_settings_sources.base import RemoteSettingsSource
|
from configs.remote_settings_sources.base import RemoteSettingsSource
|
||||||
|
|
||||||
from .utils import _parse_config
|
from .utils import parse_config
|
||||||
|
|
||||||
|
|
||||||
class NacosSettingsSource(RemoteSettingsSource):
|
class NacosSettingsSource(RemoteSettingsSource):
|
||||||
def __init__(self, configs: Mapping[str, Any]):
|
def __init__(self, configs: Mapping[str, Any]):
|
||||||
self.configs = configs
|
self.configs = configs
|
||||||
self.remote_configs: dict[str, Any] = {}
|
self.remote_configs: dict[str, str] = {}
|
||||||
self.async_init()
|
self.async_init()
|
||||||
|
|
||||||
def async_init(self):
|
def async_init(self) -> None:
|
||||||
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
|
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
|
||||||
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
|
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
|
||||||
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
|
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
|
||||||
@ -33,18 +33,15 @@ class NacosSettingsSource(RemoteSettingsSource):
|
|||||||
logger.exception("[get-access-token] exception occurred")
|
logger.exception("[get-access-token] exception occurred")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _parse_config(self, content: str) -> dict:
|
def _parse_config(self, content: str) -> dict[str, str]:
|
||||||
if not content:
|
if not content:
|
||||||
return {}
|
return {}
|
||||||
try:
|
try:
|
||||||
return _parse_config(self, content)
|
return parse_config(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to parse config: {e}")
|
raise RuntimeError(f"Failed to parse config: {e}")
|
||||||
|
|
||||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||||
if not isinstance(self.remote_configs, dict):
|
|
||||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
|
||||||
|
|
||||||
field_value = self.remote_configs.get(field_name)
|
field_value = self.remote_configs.get(field_name)
|
||||||
if field_value is None:
|
if field_value is None:
|
||||||
return None, field_name, False
|
return None, field_name, False
|
||||||
|
|||||||
@ -17,11 +17,17 @@ class NacosHttpClient:
|
|||||||
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
|
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
|
||||||
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
|
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
|
||||||
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
|
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
|
||||||
self.token = None
|
self.token: str | None = None
|
||||||
self.token_ttl = 18000
|
self.token_ttl = 18000
|
||||||
self.token_expire_time: float = 0
|
self.token_expire_time: float = 0
|
||||||
|
|
||||||
def http_request(self, url, method="GET", headers=None, params=None):
|
def http_request(
|
||||||
|
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
|
||||||
|
) -> str:
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
if params is None:
|
||||||
|
params = {}
|
||||||
try:
|
try:
|
||||||
self._inject_auth_info(headers, params)
|
self._inject_auth_info(headers, params)
|
||||||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||||
@ -30,7 +36,7 @@ class NacosHttpClient:
|
|||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
return f"Request to Nacos failed: {e}"
|
return f"Request to Nacos failed: {e}"
|
||||||
|
|
||||||
def _inject_auth_info(self, headers, params, module="config"):
|
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
|
||||||
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
|
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
|
||||||
|
|
||||||
if module == "login":
|
if module == "login":
|
||||||
@ -45,16 +51,17 @@ class NacosHttpClient:
|
|||||||
headers["timeStamp"] = ts
|
headers["timeStamp"] = ts
|
||||||
if self.username and self.password:
|
if self.username and self.password:
|
||||||
self.get_access_token(force_refresh=False)
|
self.get_access_token(force_refresh=False)
|
||||||
params["accessToken"] = self.token
|
if self.token is not None:
|
||||||
|
params["accessToken"] = self.token
|
||||||
|
|
||||||
def __do_sign(self, sign_str, sk):
|
def __do_sign(self, sign_str: str, sk: str) -> str:
|
||||||
return (
|
return (
|
||||||
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
|
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
|
||||||
.decode()
|
.decode()
|
||||||
.strip()
|
.strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_sign_str(self, group, tenant, ts):
|
def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
|
||||||
sign_str = ""
|
sign_str = ""
|
||||||
if tenant:
|
if tenant:
|
||||||
sign_str = tenant + "+"
|
sign_str = tenant + "+"
|
||||||
@ -63,7 +70,7 @@ class NacosHttpClient:
|
|||||||
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
||||||
return sign_str
|
return sign_str
|
||||||
|
|
||||||
def get_access_token(self, force_refresh=False):
|
def get_access_token(self, force_refresh: bool = False) -> str | None:
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
if self.token and not force_refresh and self.token_expire_time > current_time:
|
if self.token and not force_refresh and self.token_expire_time > current_time:
|
||||||
return self.token
|
return self.token
|
||||||
@ -77,6 +84,7 @@ class NacosHttpClient:
|
|||||||
self.token = response_data.get("accessToken")
|
self.token = response_data.get("accessToken")
|
||||||
self.token_ttl = response_data.get("tokenTtl", 18000)
|
self.token_ttl = response_data.get("tokenTtl", 18000)
|
||||||
self.token_expire_time = current_time + self.token_ttl - 10
|
self.token_expire_time = current_time + self.token_ttl - 10
|
||||||
|
return self.token
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[get-access-token] exception occur")
|
logger.exception("[get-access-token] exception occur")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
def _parse_config(self, content: str) -> dict[str, str]:
|
def parse_config(content: str) -> dict[str, str]:
|
||||||
config: dict[str, str] = {}
|
config: dict[str, str] = {}
|
||||||
if not content:
|
if not content:
|
||||||
return config
|
return config
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
@ -6,6 +8,8 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
@ -14,9 +18,9 @@ from extensions.ext_database import db
|
|||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, InstalledApp, RecommendedApp
|
||||||
|
|
||||||
|
|
||||||
def admin_required(view):
|
def admin_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if not dify_config.ADMIN_API_KEY:
|
if not dify_config.ADMIN_API_KEY:
|
||||||
raise Unauthorized("API key is invalid.")
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
|
|||||||
@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource):
|
|||||||
custom="max_keys_exceeded",
|
custom="max_keys_exceeded",
|
||||||
)
|
)
|
||||||
|
|
||||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||||
api_token = ApiToken()
|
api_token = ApiToken()
|
||||||
setattr(api_token, self.resource_id_field, resource_id)
|
setattr(api_token, self.resource_id_field, resource_id)
|
||||||
api_token.tenant_id = current_user.current_tenant_id
|
api_token.tenant_id = current_user.current_tenant_id
|
||||||
|
|||||||
@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self) -> dict:
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, NoReturn
|
from typing import NoReturn
|
||||||
|
|
||||||
from flask import Response
|
from flask import Response
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
@ -31,7 +31,7 @@ from services.workflow_service import WorkflowService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
def _convert_values_to_json_serializable_object(value: Segment):
|
||||||
if isinstance(value, FileSegment):
|
if isinstance(value, FileSegment):
|
||||||
return value.value.model_dump()
|
return value.value.model_dump()
|
||||||
elif isinstance(value, ArrayFileSegment):
|
elif isinstance(value, ArrayFileSegment):
|
||||||
@ -42,8 +42,7 @@ def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
|||||||
return value.value
|
return value.value
|
||||||
|
|
||||||
|
|
||||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||||
"""Serialize variable value. If variable is truncated, return the truncated value."""
|
|
||||||
value = variable.get_value()
|
value = variable.get_value()
|
||||||
# create a copy of the value to avoid affecting the model cache.
|
# create a copy of the value to avoid affecting the model cache.
|
||||||
value = value.model_copy(deep=True)
|
value = value.model_copy(deep=True)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import cast
|
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||||
|
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
|||||||
|
|
||||||
from .. import api
|
from .. import api
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
def oauth_server_client_id_required(view):
|
|
||||||
|
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
|
|||||||
if not oauth_provider_app:
|
if not oauth_provider_app:
|
||||||
raise NotFound("client_id is invalid")
|
raise NotFound("client_id is invalid")
|
||||||
|
|
||||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
return view(self, oauth_provider_app, *args, **kwargs)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def oauth_server_access_token_required(view):
|
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
|
||||||
raise BadRequest("Invalid oauth_provider_app")
|
raise BadRequest("Invalid oauth_provider_app")
|
||||||
|
|
||||||
authorization_header = request.headers.get("Authorization")
|
authorization_header = request.headers.get("Authorization")
|
||||||
@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
|
|||||||
response.headers["WWW-Authenticate"] = "Bearer"
|
response.headers["WWW-Authenticate"] = "Bearer"
|
||||||
return response
|
return response
|
||||||
|
|
||||||
kwargs["account"] = account
|
return view(self, oauth_provider_app, account, *args, **kwargs)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
|
from models.model import Account
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
|
||||||
@ -17,9 +17,10 @@ class Subscription(Resource):
|
|||||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
return BillingService.get_subscription(
|
return BillingService.get_subscription(
|
||||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||||
)
|
)
|
||||||
@ -31,7 +32,9 @@ class Invoices(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -477,6 +477,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
data_source_info = document.data_source_info_dict
|
data_source_info = document.data_source_info_dict
|
||||||
|
|
||||||
if document.data_source_type == "upload_file":
|
if document.data_source_type == "upload_file":
|
||||||
|
if not data_source_info:
|
||||||
|
continue
|
||||||
file_id = data_source_info["upload_file_id"]
|
file_id = data_source_info["upload_file_id"]
|
||||||
file_detail = (
|
file_detail = (
|
||||||
db.session.query(UploadFile)
|
db.session.query(UploadFile)
|
||||||
@ -493,6 +495,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
|
|
||||||
elif document.data_source_type == "notion_import":
|
elif document.data_source_type == "notion_import":
|
||||||
|
if not data_source_info:
|
||||||
|
continue
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.NOTION.value,
|
datasource_type=DatasourceType.NOTION.value,
|
||||||
notion_info={
|
notion_info={
|
||||||
@ -506,6 +510,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
elif document.data_source_type == "website_crawl":
|
elif document.data_source_type == "website_crawl":
|
||||||
|
if not data_source_info:
|
||||||
|
continue
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.WEBSITE.value,
|
datasource_type=DatasourceType.WEBSITE.value,
|
||||||
website_info={
|
website_info={
|
||||||
|
|||||||
@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource):
|
|||||||
def get(self, installed_app: InstalledApp):
|
def get(self, installed_app: InstalledApp):
|
||||||
"""Get app meta"""
|
"""Get app meta"""
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
if not app_model:
|
||||||
|
raise ValueError("App not found")
|
||||||
return AppService().get_app_meta(app_model)
|
return AppService().get_app_meta(app_model)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -36,6 +36,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
|||||||
Run workflow
|
Run workflow
|
||||||
"""
|
"""
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
if not app_model:
|
||||||
|
raise NotWorkflowAppError()
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode != AppMode.WORKFLOW:
|
if app_mode != AppMode.WORKFLOW:
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
@ -74,6 +76,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
|||||||
Stop workflow task
|
Stop workflow task
|
||||||
"""
|
"""
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
if not app_model:
|
||||||
|
raise NotWorkflowAppError()
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode != AppMode.WORKFLOW:
|
if app_mode != AppMode.WORKFLOW:
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Concatenate, Optional, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@ -13,19 +15,15 @@ from services.app_service import AppService
|
|||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
def installed_app_required(view=None):
|
|
||||||
def decorator(view):
|
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||||
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||||
if not kwargs.get("installed_app_id"):
|
|
||||||
raise ValueError("missing installed_app_id in path parameters")
|
|
||||||
|
|
||||||
installed_app_id = kwargs.get("installed_app_id")
|
|
||||||
installed_app_id = str(installed_app_id)
|
|
||||||
|
|
||||||
del kwargs["installed_app_id"]
|
|
||||||
|
|
||||||
installed_app = (
|
installed_app = (
|
||||||
db.session.query(InstalledApp)
|
db.session.query(InstalledApp)
|
||||||
.where(
|
.where(
|
||||||
@ -52,10 +50,10 @@ def installed_app_required(view=None):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def user_allowed_to_access_app(view=None):
|
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||||
def decorator(view):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app: InstalledApp, *args, **kwargs):
|
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||||
feature = FeatureService.get_system_features()
|
feature = FeatureService.get_system_features()
|
||||||
if feature.webapp_auth.enabled:
|
if feature.webapp_auth.enabled:
|
||||||
app_id = installed_app.app_id
|
app_id = installed_app.app_id
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import TenantPluginPermission
|
from models.account import TenantPluginPermission
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def plugin_permission_required(
|
def plugin_permission_required(
|
||||||
install_required: bool = False,
|
install_required: bool = False,
|
||||||
debug_required: bool = False,
|
debug_required: bool = False,
|
||||||
):
|
):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
user = current_user
|
user = current_user
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,9 @@ import contextlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
@ -19,10 +21,13 @@ from services.operation_service import OperationService
|
|||||||
|
|
||||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
def account_initialization_required(view):
|
|
||||||
|
def account_initialization_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
# check account initialization
|
# check account initialization
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@ -34,9 +39,9 @@ def account_initialization_required(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_cloud(view):
|
def only_edition_cloud(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if dify_config.EDITION != "CLOUD":
|
if dify_config.EDITION != "CLOUD":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@ -45,9 +50,9 @@ def only_edition_cloud(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_enterprise(view):
|
def only_edition_enterprise(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if not dify_config.ENTERPRISE_ENABLED:
|
if not dify_config.ENTERPRISE_ENABLED:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@ -56,9 +61,9 @@ def only_edition_enterprise(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_self_hosted(view):
|
def only_edition_self_hosted(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if dify_config.EDITION != "SELF_HOSTED":
|
if dify_config.EDITION != "SELF_HOSTED":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_enabled(view):
|
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if not features.billing.enabled:
|
if not features.billing.enabled:
|
||||||
abort(403, "Billing feature is not enabled.")
|
abort(403, "Billing feature is not enabled.")
|
||||||
@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str):
|
def cloud_edition_billing_resource_check(resource: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
members = features.members
|
members = features.members
|
||||||
@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
if resource == "add_segment":
|
if resource == "add_segment":
|
||||||
@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if resource == "knowledge":
|
if resource == "knowledge":
|
||||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||||
if knowledge_rate_limit.enabled:
|
if knowledge_rate_limit.enabled:
|
||||||
@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_utm_record(view):
|
def cloud_utm_record(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
||||||
@ -194,9 +199,9 @@ def cloud_utm_record(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def setup_required(view):
|
def setup_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
# check setup
|
# check setup
|
||||||
if (
|
if (
|
||||||
dify_config.EDITION == "SELF_HOSTED"
|
dify_config.EDITION == "SELF_HOSTED"
|
||||||
@ -212,9 +217,9 @@ def setup_required(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enterprise_license_required(view):
|
def enterprise_license_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
settings = FeatureService.get_system_features()
|
settings = FeatureService.get_system_features()
|
||||||
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
||||||
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
||||||
@ -224,9 +229,9 @@ def enterprise_license_required(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def email_password_login_enabled(view):
|
def email_password_login_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
if features.enable_email_password_login:
|
if features.enable_email_password_login:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -237,9 +242,9 @@ def email_password_login_enabled(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enable_change_email(view):
|
def enable_change_email(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
if features.enable_change_email:
|
if features.enable_change_email:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -250,9 +255,9 @@ def enable_change_email(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def is_allow_transfer_owner(view):
|
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.is_allow_transfer_workspace:
|
if features.is_allow_transfer_workspace:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|||||||
@ -99,7 +99,7 @@ class MCPAppApi(Resource):
|
|||||||
|
|
||||||
return mcp_server, app
|
return mcp_server, app
|
||||||
|
|
||||||
def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
|
def _validate_server_status(self, mcp_server: AppMCPServer):
|
||||||
"""Validate MCP server status"""
|
"""Validate MCP server status"""
|
||||||
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
if mcp_server.status != AppMCPServerStatus.ACTIVE:
|
||||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
|
||||||
|
|||||||
@ -440,7 +440,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
# validate segment belongs to the specified document
|
# validate segment belongs to the specified document
|
||||||
if segment.document_id != document_id:
|
if str(segment.document_id) != str(document_id):
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
# check child chunk
|
# check child chunk
|
||||||
@ -451,7 +451,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Child chunk not found.")
|
raise NotFound("Child chunk not found.")
|
||||||
|
|
||||||
# validate child chunk belongs to the specified segment
|
# validate child chunk belongs to the specified segment
|
||||||
if child_chunk.segment_id != segment.id:
|
if str(child_chunk.segment_id) != str(segment.id):
|
||||||
raise NotFound("Child chunk not found.")
|
raise NotFound("Child chunk not found.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -500,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
# validate segment belongs to the specified document
|
# validate segment belongs to the specified document
|
||||||
if segment.document_id != document_id:
|
if str(segment.document_id) != str(document_id):
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
|
|
||||||
# get child chunk
|
# get child chunk
|
||||||
@ -511,7 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||||||
raise NotFound("Child chunk not found.")
|
raise NotFound("Child chunk not found.")
|
||||||
|
|
||||||
# validate child chunk belongs to the specified segment
|
# validate child chunk belongs to the specified segment
|
||||||
if child_chunk.segment_id != segment.id:
|
if str(child_chunk.segment_id) != str(segment.id):
|
||||||
raise NotFound("Child chunk not found.")
|
raise NotFound("Child chunk not found.")
|
||||||
|
|
||||||
# validate args
|
# validate args
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from collections.abc import Callable
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional
|
from typing import Optional, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import current_app, request
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
|
|||||||
from models.model import ApiToken, App, EndUser
|
from models.model import ApiToken, App, EndUser
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
class WhereisUserArg(StrEnum):
|
class WhereisUserArg(StrEnum):
|
||||||
"""
|
"""
|
||||||
@ -60,27 +63,6 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
|||||||
if tenant.status == TenantStatus.ARCHIVE:
|
if tenant.status == TenantStatus.ARCHIVE:
|
||||||
raise Forbidden("The workspace's status is archived.")
|
raise Forbidden("The workspace's status is archived.")
|
||||||
|
|
||||||
tenant_account_join = (
|
|
||||||
db.session.query(Tenant, TenantAccountJoin)
|
|
||||||
.where(Tenant.id == api_token.tenant_id)
|
|
||||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
|
||||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
|
||||||
.where(Tenant.status == TenantStatus.NORMAL)
|
|
||||||
.one_or_none()
|
|
||||||
) # TODO: only owner information is required, so only one is returned.
|
|
||||||
if tenant_account_join:
|
|
||||||
tenant, ta = tenant_account_join
|
|
||||||
account = db.session.query(Account).where(Account.id == ta.account_id).first()
|
|
||||||
# Login admin
|
|
||||||
if account:
|
|
||||||
account.current_tenant = tenant
|
|
||||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
|
||||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
|
||||||
else:
|
|
||||||
raise Unauthorized("Tenant owner account does not exist.")
|
|
||||||
else:
|
|
||||||
raise Unauthorized("Tenant does not exist.")
|
|
||||||
|
|
||||||
kwargs["app_model"] = app_model
|
kwargs["app_model"] = app_model
|
||||||
|
|
||||||
if fetch_user_arg:
|
if fetch_user_arg:
|
||||||
@ -118,8 +100,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
features = FeatureService.get_features(api_token.tenant_id)
|
features = FeatureService.get_features(api_token.tenant_id)
|
||||||
|
|
||||||
@ -148,9 +130,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
features = FeatureService.get_features(api_token.tenant_id)
|
features = FeatureService.get_features(api_token.tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
@ -170,9 +152,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
|
|
||||||
if resource == "knowledge":
|
if resource == "knowledge":
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
|
|||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.webapp_auth_service import WebAppAuthService
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def validate_jwt_token(view=None):
|
def validate_jwt_token(view=None):
|
||||||
def decorator(view):
|
def decorator(view):
|
||||||
|
|||||||
@ -62,7 +62,7 @@ class BaseAgentRunner(AppRunner):
|
|||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
memory: Optional[TokenBufferMemory] = None,
|
memory: Optional[TokenBufferMemory] = None,
|
||||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||||
) -> None:
|
):
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
self.conversation = conversation
|
self.conversation = conversation
|
||||||
|
|||||||
@ -338,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||||||
|
|
||||||
return instruction
|
return instruction
|
||||||
|
|
||||||
def _init_react_state(self, query) -> None:
|
def _init_react_state(self, query):
|
||||||
"""
|
"""
|
||||||
init agent scratchpad
|
init agent scratchpad
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel):
|
|||||||
action_name: str
|
action_name: str
|
||||||
action_input: Union[dict, str]
|
action_input: Union[dict, str]
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
Convert to dictionary.
|
Convert to dictionary.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -158,7 +158,7 @@ class DatasetConfigManager:
|
|||||||
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
|
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict:
|
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
|
||||||
"""
|
"""
|
||||||
Extract dataset config for legacy compatibility
|
Extract dataset config for legacy compatibility
|
||||||
|
|
||||||
|
|||||||
@ -105,7 +105,7 @@ class ModelConfigManager:
|
|||||||
return dict(config), ["model"]
|
return dict(config), ["model"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_model_completion_params(cls, cp: dict) -> dict:
|
def validate_model_completion_params(cls, cp: dict):
|
||||||
# model.completion_params
|
# model.completion_params
|
||||||
if not isinstance(cp, dict):
|
if not isinstance(cp, dict):
|
||||||
raise ValueError("model.completion_params must be of object type")
|
raise ValueError("model.completion_params must be of object type")
|
||||||
|
|||||||
@ -122,7 +122,7 @@ class PromptTemplateConfigManager:
|
|||||||
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
|
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict:
|
def validate_post_prompt_and_set_defaults(cls, config: dict):
|
||||||
"""
|
"""
|
||||||
Validate post_prompt and set defaults for prompt feature
|
Validate post_prompt and set defaults for prompt feature
|
||||||
|
|
||||||
|
|||||||
@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
|||||||
return app_config
|
return app_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
|
||||||
"""
|
"""
|
||||||
Validate for advanced chat app model config
|
Validate for advanced chat app model config
|
||||||
|
|
||||||
|
|||||||
@ -481,7 +481,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
message_id: str,
|
message_id: str,
|
||||||
context: contextvars.Context,
|
context: contextvars.Context,
|
||||||
variable_loader: VariableLoader,
|
variable_loader: VariableLoader,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
:param flask_app: Flask app
|
:param flask_app: Flask app
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
system_user_id: str,
|
system_user_id: str,
|
||||||
app: App,
|
app: App,
|
||||||
) -> None:
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
variable_loader=variable_loader,
|
variable_loader=variable_loader,
|
||||||
@ -69,7 +69,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
self.system_user_id = system_user_id
|
self.system_user_id = system_user_id
|
||||||
self._app = app
|
self._app = app
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self):
|
||||||
app_config = self.application_generate_entity.app_config
|
app_config = self.application_generate_entity.app_config
|
||||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||||
|
|
||||||
@ -184,6 +184,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
),
|
),
|
||||||
invoke_from=self.application_generate_entity.invoke_from,
|
invoke_from=self.application_generate_entity.invoke_from,
|
||||||
call_depth=self.application_generate_entity.call_depth,
|
call_depth=self.application_generate_entity.call_depth,
|
||||||
|
variable_pool=variable_pool,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=command_channel,
|
command_channel=command_channel,
|
||||||
)
|
)
|
||||||
@ -238,7 +239,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy):
|
||||||
"""
|
"""
|
||||||
Direct output
|
Direct output
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -96,7 +96,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||||
) -> None:
|
):
|
||||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
@ -284,7 +284,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
session.rollback()
|
session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _ensure_workflow_initialized(self) -> None:
|
def _ensure_workflow_initialized(self):
|
||||||
"""Fluent validation for workflow state."""
|
"""Fluent validation for workflow state."""
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
@ -835,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
if self._conversation_name_generate_thread:
|
if self._conversation_name_generate_thread:
|
||||||
self._conversation_name_generate_thread.join()
|
self._conversation_name_generate_thread.join()
|
||||||
|
|
||||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None):
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
|
|
||||||
# If there are assistant files, remove markdown image links from answer
|
# If there are assistant files, remove markdown image links from answer
|
||||||
|
|||||||
@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||||||
return app_config
|
return app_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
|
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
|
||||||
"""
|
"""
|
||||||
Validate for agent chat app model config
|
Validate for agent chat app model config
|
||||||
|
|
||||||
|
|||||||
@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
:param flask_app: Flask app
|
:param flask_app: Flask app
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class AgentChatAppRunner(AppRunner):
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Run assistant application
|
Run assistant application
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
|
|||||||
@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
_blocking_response_type = ChatbotAppBlockingResponse
|
_blocking_response_type = ChatbotAppBlockingResponse
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Convert blocking full response.
|
Convert blocking full response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Convert blocking simple response.
|
Convert blocking simple response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
|
|||||||
@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC):
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _error_to_stream_response(cls, e: Exception) -> dict:
|
def _error_to_stream_response(cls, e: Exception):
|
||||||
"""
|
"""
|
||||||
Error to stream response.
|
Error to stream response.
|
||||||
:param e: exception
|
:param e: exception
|
||||||
|
|||||||
@ -158,7 +158,7 @@ class BaseAppGenerator:
|
|||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _sanitize_value(self, value: Any) -> Any:
|
def _sanitize_value(self, value: Any):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return value.replace("\x00", "")
|
return value.replace("\x00", "")
|
||||||
return value
|
return value
|
||||||
|
|||||||
@ -25,7 +25,7 @@ class PublishFrom(IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class AppQueueManager:
|
class AppQueueManager:
|
||||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
|
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
|
||||||
if not user_id:
|
if not user_id:
|
||||||
raise ValueError("user is required")
|
raise ValueError("user is required")
|
||||||
|
|
||||||
@ -73,14 +73,14 @@ class AppQueueManager:
|
|||||||
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
||||||
last_ping_time = elapsed_time // 10
|
last_ping_time = elapsed_time // 10
|
||||||
|
|
||||||
def stop_listen(self) -> None:
|
def stop_listen(self):
|
||||||
"""
|
"""
|
||||||
Stop listen to queue
|
Stop listen to queue
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
self._q.put(None)
|
self._q.put(None)
|
||||||
|
|
||||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
def publish_error(self, e, pub_from: PublishFrom):
|
||||||
"""
|
"""
|
||||||
Publish error
|
Publish error
|
||||||
:param e: error
|
:param e: error
|
||||||
@ -89,7 +89,7 @@ class AppQueueManager:
|
|||||||
"""
|
"""
|
||||||
self.publish(QueueErrorEvent(error=e), pub_from)
|
self.publish(QueueErrorEvent(error=e), pub_from)
|
||||||
|
|
||||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||||
"""
|
"""
|
||||||
Publish event to queue
|
Publish event to queue
|
||||||
:param event:
|
:param event:
|
||||||
@ -100,7 +100,7 @@ class AppQueueManager:
|
|||||||
self._publish(event, pub_from)
|
self._publish(event, pub_from)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||||
"""
|
"""
|
||||||
Publish event to queue
|
Publish event to queue
|
||||||
:param event:
|
:param event:
|
||||||
@ -110,7 +110,7 @@ class AppQueueManager:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
|
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str):
|
||||||
"""
|
"""
|
||||||
Set task stop flag
|
Set task stop flag
|
||||||
:return:
|
:return:
|
||||||
|
|||||||
@ -162,7 +162,7 @@ class AppRunner:
|
|||||||
text: str,
|
text: str,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
usage: Optional[LLMUsage] = None,
|
usage: Optional[LLMUsage] = None,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Direct output
|
Direct output
|
||||||
:param queue_manager: application queue manager
|
:param queue_manager: application queue manager
|
||||||
@ -204,7 +204,7 @@ class AppRunner:
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
agent: bool = False,
|
agent: bool = False,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Handle invoke result
|
Handle invoke result
|
||||||
:param invoke_result: invoke result
|
:param invoke_result: invoke result
|
||||||
@ -220,9 +220,7 @@ class AppRunner:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||||
|
|
||||||
def _handle_invoke_result_direct(
|
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
|
||||||
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Handle invoke result direct
|
Handle invoke result direct
|
||||||
:param invoke_result: invoke result
|
:param invoke_result: invoke result
|
||||||
@ -239,7 +237,7 @@ class AppRunner:
|
|||||||
|
|
||||||
def _handle_invoke_result_stream(
|
def _handle_invoke_result_stream(
|
||||||
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
|
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Handle invoke result
|
Handle invoke result
|
||||||
:param invoke_result: invoke result
|
:param invoke_result: invoke result
|
||||||
|
|||||||
@ -81,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||||||
return app_config
|
return app_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def config_validate(cls, tenant_id: str, config: dict) -> dict:
|
def config_validate(cls, tenant_id: str, config: dict):
|
||||||
"""
|
"""
|
||||||
Validate for chat app model config
|
Validate for chat app model config
|
||||||
|
|
||||||
|
|||||||
@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
:param flask_app: Flask app
|
:param flask_app: Flask app
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class ChatAppRunner(AppRunner):
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Run application
|
Run application
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
|
|||||||
@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
_blocking_response_type = ChatbotAppBlockingResponse
|
_blocking_response_type = ChatbotAppBlockingResponse
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Convert blocking full response.
|
Convert blocking full response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
|
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Convert blocking simple response.
|
Convert blocking simple response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
|
|||||||
@ -56,7 +56,7 @@ class WorkflowResponseConverter:
|
|||||||
*,
|
*,
|
||||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
) -> None:
|
):
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._user = user
|
self._user = user
|
||||||
self._truncator = VariableTruncator.default()
|
self._truncator = VariableTruncator.default()
|
||||||
|
|||||||
@ -66,7 +66,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
|||||||
return app_config
|
return app_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def config_validate(cls, tenant_id: str, config: dict) -> dict:
|
def config_validate(cls, tenant_id: str, config: dict):
|
||||||
"""
|
"""
|
||||||
Validate for completion app model config
|
Validate for completion app model config
|
||||||
|
|
||||||
|
|||||||
@ -192,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
application_generate_entity: CompletionAppGenerateEntity,
|
application_generate_entity: CompletionAppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
:param flask_app: Flask app
|
:param flask_app: Flask app
|
||||||
@ -262,6 +262,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
raise MessageNotExistsError()
|
raise MessageNotExistsError()
|
||||||
|
|
||||||
current_app_model_config = app_model.app_model_config
|
current_app_model_config = app_model.app_model_config
|
||||||
|
if not current_app_model_config:
|
||||||
|
raise MoreLikeThisDisabledError()
|
||||||
|
|
||||||
more_like_this = current_app_model_config.more_like_this_dict
|
more_like_this = current_app_model_config.more_like_this_dict
|
||||||
|
|
||||||
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
|
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class CompletionAppRunner(AppRunner):
|
|||||||
|
|
||||||
def run(
|
def run(
|
||||||
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
|
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Run application
|
Run application
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
|
|||||||
@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
_blocking_response_type = CompletionAppBlockingResponse
|
_blocking_response_type = CompletionAppBlockingResponse
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
|
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Convert blocking full response.
|
Convert blocking full response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
|
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Convert blocking simple response.
|
Convert blocking simple response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
|
|||||||
@ -14,14 +14,14 @@ from core.app.entities.queue_entities import (
|
|||||||
class MessageBasedAppQueueManager(AppQueueManager):
|
class MessageBasedAppQueueManager(AppQueueManager):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
|
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
|
||||||
) -> None:
|
):
|
||||||
super().__init__(task_id, user_id, invoke_from)
|
super().__init__(task_id, user_id, invoke_from)
|
||||||
|
|
||||||
self._conversation_id = str(conversation_id)
|
self._conversation_id = str(conversation_id)
|
||||||
self._app_mode = app_mode
|
self._app_mode = app_mode
|
||||||
self._message_id = str(message_id)
|
self._message_id = str(message_id)
|
||||||
|
|
||||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||||
"""
|
"""
|
||||||
Publish event to queue
|
Publish event to queue
|
||||||
:param event:
|
:param event:
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
|
|||||||
return app_config
|
return app_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
|
||||||
"""
|
"""
|
||||||
Validate for workflow app model config
|
Validate for workflow app model config
|
||||||
|
|
||||||
|
|||||||
@ -14,12 +14,12 @@ from core.app.entities.queue_entities import (
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowAppQueueManager(AppQueueManager):
|
class WorkflowAppQueueManager(AppQueueManager):
|
||||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str):
|
||||||
super().__init__(task_id, user_id, invoke_from)
|
super().__init__(task_id, user_id, invoke_from)
|
||||||
|
|
||||||
self._app_mode = app_mode
|
self._app_mode = app_mode
|
||||||
|
|
||||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||||
"""
|
"""
|
||||||
Publish event to queue
|
Publish event to queue
|
||||||
:param event:
|
:param event:
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
variable_loader: VariableLoader,
|
variable_loader: VariableLoader,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
system_user_id: str,
|
system_user_id: str,
|
||||||
) -> None:
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
variable_loader=variable_loader,
|
variable_loader=variable_loader,
|
||||||
@ -44,7 +44,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
self._workflow = workflow
|
self._workflow = workflow
|
||||||
self._sys_user_id = system_user_id
|
self._sys_user_id = system_user_id
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self):
|
||||||
"""
|
"""
|
||||||
Run application
|
Run application
|
||||||
"""
|
"""
|
||||||
@ -127,6 +127,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
),
|
),
|
||||||
invoke_from=self.application_generate_entity.invoke_from,
|
invoke_from=self.application_generate_entity.invoke_from,
|
||||||
call_depth=self.application_generate_entity.call_depth,
|
call_depth=self.application_generate_entity.call_depth,
|
||||||
|
variable_pool=variable_pool,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
command_channel=command_channel,
|
command_channel=command_channel,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
_blocking_response_type = WorkflowAppBlockingResponse
|
_blocking_response_type = WorkflowAppBlockingResponse
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Convert blocking full response.
|
Convert blocking full response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
return dict(blocking_response.to_dict())
|
return dict(blocking_response.to_dict())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Convert blocking simple response.
|
Convert blocking simple response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
|
|||||||
@ -88,7 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||||
) -> None:
|
):
|
||||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
@ -259,7 +259,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||||||
session.rollback()
|
session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _ensure_workflow_initialized(self) -> None:
|
def _ensure_workflow_initialized(self):
|
||||||
"""Fluent validation for workflow state."""
|
"""Fluent validation for workflow state."""
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
@ -697,7 +697,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||||||
if tts_publisher:
|
if tts_publisher:
|
||||||
tts_publisher.publish(None)
|
tts_publisher.publish(None)
|
||||||
|
|
||||||
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
|
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
|
||||||
invoke_from = self._application_generate_entity.invoke_from
|
invoke_from = self._application_generate_entity.invoke_from
|
||||||
if invoke_from == InvokeFrom.SERVICE_API:
|
if invoke_from == InvokeFrom.SERVICE_API:
|
||||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||||
|
|||||||
@ -67,7 +67,7 @@ class WorkflowBasedAppRunner:
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
) -> None:
|
):
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
self._variable_loader = variable_loader
|
self._variable_loader = variable_loader
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
@ -348,7 +348,7 @@ class WorkflowBasedAppRunner:
|
|||||||
|
|
||||||
return graph, variable_pool
|
return graph, variable_pool
|
||||||
|
|
||||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
|
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||||
"""
|
"""
|
||||||
Handle event
|
Handle event
|
||||||
:param workflow_entry: workflow entry
|
:param workflow_entry: workflow entry
|
||||||
@ -580,5 +580,5 @@ class WorkflowBasedAppRunner:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
def _publish_event(self, event: AppQueueEvent):
|
||||||
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class BasedGenerateTaskPipeline:
|
|||||||
application_generate_entity: AppGenerateEntity,
|
application_generate_entity: AppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
) -> None:
|
):
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self.queue_manager = queue_manager
|
self.queue_manager = queue_manager
|
||||||
self._start_at = time.perf_counter()
|
self._start_at = time.perf_counter()
|
||||||
|
|||||||
@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
) -> None:
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
if self._conversation_name_generate_thread:
|
if self._conversation_name_generate_thread:
|
||||||
self._conversation_name_generate_thread.join()
|
self._conversation_name_generate_thread.join()
|
||||||
|
|
||||||
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
|
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None):
|
||||||
"""
|
"""
|
||||||
Save message.
|
Save message.
|
||||||
:return:
|
:return:
|
||||||
@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
application_generate_entity=self._application_generate_entity,
|
application_generate_entity=self._application_generate_entity,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_stop(self, event: QueueStopEvent) -> None:
|
def _handle_stop(self, event: QueueStopEvent):
|
||||||
"""
|
"""
|
||||||
Handle stop.
|
Handle stop.
|
||||||
:return:
|
:return:
|
||||||
|
|||||||
@ -48,7 +48,7 @@ class MessageCycleManager:
|
|||||||
AdvancedChatAppGenerateEntity,
|
AdvancedChatAppGenerateEntity,
|
||||||
],
|
],
|
||||||
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||||
) -> None:
|
):
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._task_state = task_state
|
self._task_state = task_state
|
||||||
|
|
||||||
@ -132,7 +132,7 @@ class MessageCycleManager:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent):
|
||||||
"""
|
"""
|
||||||
Handle retriever resources.
|
Handle retriever resources.
|
||||||
:param event: event
|
:param event: event
|
||||||
|
|||||||
@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str:
|
|||||||
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||||
|
|
||||||
|
|
||||||
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
|
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None):
|
||||||
"""Print text with highlighting and no end characters."""
|
"""Print text with highlighting and no end characters."""
|
||||||
text_to_print = get_colored_text(text, color) if color else text
|
text_to_print = get_colored_text(text, color) if color else text
|
||||||
print(text_to_print, end=end, file=file)
|
print(text_to_print, end=end, file=file)
|
||||||
@ -37,7 +37,7 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
color: Optional[str] = ""
|
color: Optional[str] = ""
|
||||||
current_loop: int = 1
|
current_loop: int = 1
|
||||||
|
|
||||||
def __init__(self, color: Optional[str] = None) -> None:
|
def __init__(self, color: Optional[str] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
"""Initialize callback handler."""
|
"""Initialize callback handler."""
|
||||||
# use a specific color is not specified
|
# use a specific color is not specified
|
||||||
@ -48,7 +48,7 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_inputs: Mapping[str, Any],
|
tool_inputs: Mapping[str, Any],
|
||||||
) -> None:
|
):
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
|
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
|
||||||
@ -61,7 +61,7 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
timer: Optional[Any] = None,
|
timer: Optional[Any] = None,
|
||||||
trace_manager: Optional[TraceQueueManager] = None,
|
trace_manager: Optional[TraceQueueManager] = None,
|
||||||
) -> None:
|
):
|
||||||
"""If not the final action, print out observation."""
|
"""If not the final action, print out observation."""
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
print_text("\n[on_tool_end]\n", color=self.color)
|
print_text("\n[on_tool_end]\n", color=self.color)
|
||||||
@ -82,12 +82,12 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
|
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
|
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
|
||||||
|
|
||||||
def on_agent_start(self, thought: str) -> None:
|
def on_agent_start(self, thought: str):
|
||||||
"""Run on agent start."""
|
"""Run on agent start."""
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
if thought:
|
if thought:
|
||||||
@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
else:
|
else:
|
||||||
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
|
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
|
||||||
|
|
||||||
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
|
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any):
|
||||||
"""Run on agent end."""
|
"""Run on agent end."""
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
|
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
|
||||||
|
|||||||
@ -21,14 +21,14 @@ class DatasetIndexToolCallbackHandler:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
|
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
|
||||||
) -> None:
|
):
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
self._message_id = message_id
|
self._message_id = message_id
|
||||||
self._user_id = user_id
|
self._user_id = user_id
|
||||||
self._invoke_from = invoke_from
|
self._invoke_from = invoke_from
|
||||||
|
|
||||||
def on_query(self, query: str, dataset_id: str) -> None:
|
def on_query(self, query: str, dataset_id: str):
|
||||||
"""
|
"""
|
||||||
Handle query.
|
Handle query.
|
||||||
"""
|
"""
|
||||||
@ -46,7 +46,7 @@ class DatasetIndexToolCallbackHandler:
|
|||||||
db.session.add(dataset_query)
|
db.session.add(dataset_query)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def on_tool_end(self, documents: list[Document]) -> None:
|
def on_tool_end(self, documents: list[Document]):
|
||||||
"""Handle tool end."""
|
"""Handle tool end."""
|
||||||
for document in documents:
|
for document in documents:
|
||||||
if document.metadata is not None:
|
if document.metadata is not None:
|
||||||
|
|||||||
@ -30,8 +30,6 @@ class FakeDatasourceRuntime(DatasourceRuntime):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
tenant_id="fake_tenant_id",
|
tenant_id="fake_tenant_id",
|
||||||
datasource_id="fake_datasource_id",
|
datasource_id="fake_datasource_id",
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class SimpleModelProviderEntity(BaseModel):
|
|||||||
icon_large: Optional[I18nObject] = None
|
icon_large: Optional[I18nObject] = None
|
||||||
supported_model_types: list[ModelType]
|
supported_model_types: list[ModelType]
|
||||||
|
|
||||||
def __init__(self, provider_entity: ProviderEntity) -> None:
|
def __init__(self, provider_entity: ProviderEntity):
|
||||||
"""
|
"""
|
||||||
Init simple provider.
|
Init simple provider.
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
|
|||||||
load_balancing_enabled: bool = False
|
load_balancing_enabled: bool = False
|
||||||
has_invalid_load_balancing_configs: bool = False
|
has_invalid_load_balancing_configs: bool = False
|
||||||
|
|
||||||
def raise_for_status(self) -> None:
|
def raise_for_status(self):
|
||||||
"""
|
"""
|
||||||
Check model status and raise ValueError if not active.
|
Check model status and raise ValueError if not active.
|
||||||
|
|
||||||
|
|||||||
@ -280,9 +280,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
else [],
|
else [],
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_provider_credentials(
|
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
|
||||||
self, credentials: dict, credential_id: str = "", session: Session | None = None
|
|
||||||
) -> dict:
|
|
||||||
"""
|
"""
|
||||||
Validate custom credentials.
|
Validate custom credentials.
|
||||||
:param credentials: provider credentials
|
:param credentials: provider credentials
|
||||||
@ -291,7 +289,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _validate(s: Session) -> dict:
|
def _validate(s: Session):
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
provider_credential_secret_variables = self.extract_secret_variables(
|
provider_credential_secret_variables = self.extract_secret_variables(
|
||||||
self.provider.provider_credential_schema.credential_form_schemas
|
self.provider.provider_credential_schema.credential_form_schemas
|
||||||
@ -402,7 +400,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
logger.warning("Error generating next credential name: %s", str(e))
|
logger.warning("Error generating next credential name: %s", str(e))
|
||||||
return "API KEY 1"
|
return "API KEY 1"
|
||||||
|
|
||||||
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
|
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
||||||
"""
|
"""
|
||||||
Add custom provider credentials.
|
Add custom provider credentials.
|
||||||
:param credentials: provider credentials
|
:param credentials: provider credentials
|
||||||
@ -458,7 +456,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
credentials: dict,
|
credentials: dict,
|
||||||
credential_id: str,
|
credential_id: str,
|
||||||
credential_name: str | None,
|
credential_name: str | None,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
update a saved provider credential (by credential_id).
|
update a saved provider credential (by credential_id).
|
||||||
|
|
||||||
@ -519,7 +517,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
credential_record: ProviderCredential | ProviderModelCredential,
|
credential_record: ProviderCredential | ProviderModelCredential,
|
||||||
credential_source: str,
|
credential_source: str,
|
||||||
session: Session,
|
session: Session,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Update load balancing configurations that reference the given credential_id.
|
Update load balancing configurations that reference the given credential_id.
|
||||||
|
|
||||||
@ -559,7 +557,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
def delete_provider_credential(self, credential_id: str) -> None:
|
def delete_provider_credential(self, credential_id: str):
|
||||||
"""
|
"""
|
||||||
Delete a saved provider credential (by credential_id).
|
Delete a saved provider credential (by credential_id).
|
||||||
|
|
||||||
@ -636,7 +634,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
session.rollback()
|
session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def switch_active_provider_credential(self, credential_id: str) -> None:
|
def switch_active_provider_credential(self, credential_id: str):
|
||||||
"""
|
"""
|
||||||
Switch active provider credential (copy the selected one into current active snapshot).
|
Switch active provider credential (copy the selected one into current active snapshot).
|
||||||
|
|
||||||
@ -815,7 +813,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
credentials: dict,
|
credentials: dict,
|
||||||
credential_id: str = "",
|
credential_id: str = "",
|
||||||
session: Session | None = None,
|
session: Session | None = None,
|
||||||
) -> dict:
|
):
|
||||||
"""
|
"""
|
||||||
Validate custom model credentials.
|
Validate custom model credentials.
|
||||||
|
|
||||||
@ -826,7 +824,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _validate(s: Session) -> dict:
|
def _validate(s: Session):
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
provider_credential_secret_variables = self.extract_secret_variables(
|
provider_credential_secret_variables = self.extract_secret_variables(
|
||||||
self.provider.model_credential_schema.credential_form_schemas
|
self.provider.model_credential_schema.credential_form_schemas
|
||||||
@ -1010,7 +1008,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
session.rollback()
|
session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
|
||||||
"""
|
"""
|
||||||
Delete a saved provider credential (by credential_id).
|
Delete a saved provider credential (by credential_id).
|
||||||
|
|
||||||
@ -1080,7 +1078,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
session.rollback()
|
session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
|
||||||
"""
|
"""
|
||||||
if model list exist this custom model, switch the custom model credential.
|
if model list exist this custom model, switch the custom model credential.
|
||||||
if model list not exist this custom model, use the credential to add a new custom model record.
|
if model list not exist this custom model, use the credential to add a new custom model record.
|
||||||
@ -1123,7 +1121,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
session.add(provider_model_record)
|
session.add(provider_model_record)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
|
||||||
"""
|
"""
|
||||||
switch the custom model credential.
|
switch the custom model credential.
|
||||||
|
|
||||||
@ -1153,7 +1151,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
session.add(provider_model_record)
|
session.add(provider_model_record)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
def delete_custom_model(self, model_type: ModelType, model: str) -> None:
|
def delete_custom_model(self, model_type: ModelType, model: str):
|
||||||
"""
|
"""
|
||||||
Delete custom model.
|
Delete custom model.
|
||||||
:param model_type: model type
|
:param model_type: model type
|
||||||
@ -1350,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||||
)
|
)
|
||||||
|
|
||||||
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None:
|
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
|
||||||
"""
|
"""
|
||||||
Switch preferred provider type.
|
Switch preferred provider type.
|
||||||
:param provider_type:
|
:param provider_type:
|
||||||
@ -1362,7 +1360,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
|
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _switch(s: Session) -> None:
|
def _switch(s: Session):
|
||||||
# get preferred provider
|
# get preferred provider
|
||||||
model_provider_id = ModelProviderID(self.provider.provider)
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
provider_names = [self.provider.provider]
|
provider_names = [self.provider.provider]
|
||||||
@ -1406,7 +1404,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
return secret_input_form_variables
|
return secret_input_form_variables
|
||||||
|
|
||||||
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
|
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
|
||||||
"""
|
"""
|
||||||
Obfuscated credentials.
|
Obfuscated credentials.
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ class LLMError(ValueError):
|
|||||||
|
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
def __init__(self, description: Optional[str] = None) -> None:
|
def __init__(self, description: Optional[str] = None):
|
||||||
self.description = description
|
self.description = description
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -10,11 +10,11 @@ class APIBasedExtensionRequestor:
|
|||||||
timeout: tuple[int, int] = (5, 60)
|
timeout: tuple[int, int] = (5, 60)
|
||||||
"""timeout for request connect and read"""
|
"""timeout for request connect and read"""
|
||||||
|
|
||||||
def __init__(self, api_endpoint: str, api_key: str) -> None:
|
def __init__(self, api_endpoint: str, api_key: str):
|
||||||
self.api_endpoint = api_endpoint
|
self.api_endpoint = api_endpoint
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
||||||
def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
|
def request(self, point: APIBasedExtensionPoint, params: dict):
|
||||||
"""
|
"""
|
||||||
Request the api.
|
Request the api.
|
||||||
|
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class Extensible:
|
|||||||
tenant_id: str
|
tenant_id: str
|
||||||
config: Optional[dict] = None
|
config: Optional[dict] = None
|
||||||
|
|
||||||
def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
|
def __init__(self, tenant_id: str, config: Optional[dict] = None):
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
|||||||
"""the unique name of external data tool"""
|
"""the unique name of external data tool"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
def validate_config(cls, tenant_id: str, config: dict):
|
||||||
"""
|
"""
|
||||||
Validate the incoming form config data.
|
Validate the incoming form config data.
|
||||||
|
|
||||||
|
|||||||
@ -16,14 +16,14 @@ class ExternalDataTool(Extensible, ABC):
|
|||||||
variable: str
|
variable: str
|
||||||
"""the tool variable name of app tool"""
|
"""the tool variable name of app tool"""
|
||||||
|
|
||||||
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
|
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None):
|
||||||
super().__init__(tenant_id, config)
|
super().__init__(tenant_id, config)
|
||||||
self.app_id = app_id
|
self.app_id = app_id
|
||||||
self.variable = variable
|
self.variable = variable
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
def validate_config(cls, tenant_id: str, config: dict):
|
||||||
"""
|
"""
|
||||||
Validate the incoming form config data.
|
Validate the incoming form config data.
|
||||||
|
|
||||||
|
|||||||
@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
|
|||||||
|
|
||||||
|
|
||||||
class ExternalDataToolFactory:
|
class ExternalDataToolFactory:
|
||||||
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
|
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
|
||||||
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||||
self.__extension_instance = extension_class(
|
self.__extension_instance = extension_class(
|
||||||
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
|
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
|
def validate_config(cls, name: str, tenant_id: str, config: dict):
|
||||||
"""
|
"""
|
||||||
Validate the incoming form config data.
|
Validate the incoming form config data.
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,6 @@ if TYPE_CHECKING:
|
|||||||
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
|
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
|
||||||
|
|
||||||
|
|
||||||
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
|
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]):
|
||||||
global _tool_file_manager_factory
|
global _tool_file_manager_factory
|
||||||
_tool_file_manager_factory = factory
|
_tool_file_manager_factory = factory
|
||||||
|
|||||||
@ -22,7 +22,7 @@ class CodeNodeProvider(BaseModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls) -> dict:
|
def get_default_config(cls):
|
||||||
return {
|
return {
|
||||||
"type": "code",
|
"type": "code",
|
||||||
"config": {
|
"config": {
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
|||||||
|
|
||||||
class Jinja2TemplateTransformer(TemplateTransformer):
|
class Jinja2TemplateTransformer(TemplateTransformer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def transform_response(cls, response: str) -> dict:
|
def transform_response(cls, response: str):
|
||||||
"""
|
"""
|
||||||
Transform response to dict
|
Transform response to dict
|
||||||
:param response: response
|
:param response: response
|
||||||
|
|||||||
@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
|
|||||||
def get_default_code(cls) -> str:
|
def get_default_code(cls) -> str:
|
||||||
return dedent(
|
return dedent(
|
||||||
"""
|
"""
|
||||||
def main(arg1: str, arg2: str) -> dict:
|
def main(arg1: str, arg2: str):
|
||||||
return {
|
return {
|
||||||
"result": arg1 + arg2,
|
"result": arg1 + arg2,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class ProviderCredentialsCache:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set(self, credentials: dict) -> None:
|
def set(self, credentials: dict):
|
||||||
"""
|
"""
|
||||||
Cache model provider credentials.
|
Cache model provider credentials.
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ class ProviderCredentialsCache:
|
|||||||
"""
|
"""
|
||||||
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self):
|
||||||
"""
|
"""
|
||||||
Delete cached model provider credentials.
|
Delete cached model provider credentials.
|
||||||
|
|
||||||
|
|||||||
@ -28,11 +28,11 @@ class ProviderCredentialsCache(ABC):
|
|||||||
return None
|
return None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set(self, config: dict[str, Any]) -> None:
|
def set(self, config: dict[str, Any]):
|
||||||
"""Cache provider credentials"""
|
"""Cache provider credentials"""
|
||||||
redis_client.setex(self.cache_key, 86400, json.dumps(config))
|
redis_client.setex(self.cache_key, 86400, json.dumps(config))
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self):
|
||||||
"""Delete cached provider credentials"""
|
"""Delete cached provider credentials"""
|
||||||
redis_client.delete(self.cache_key)
|
redis_client.delete(self.cache_key)
|
||||||
|
|
||||||
@ -75,10 +75,10 @@ class NoOpProviderCredentialCache:
|
|||||||
"""Get cached provider credentials"""
|
"""Get cached provider credentials"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set(self, config: dict[str, Any]) -> None:
|
def set(self, config: dict[str, Any]):
|
||||||
"""Cache provider credentials"""
|
"""Cache provider credentials"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self):
|
||||||
"""Delete cached provider credentials"""
|
"""Delete cached provider credentials"""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -37,11 +37,11 @@ class ToolParameterCache:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set(self, parameters: dict) -> None:
|
def set(self, parameters: dict):
|
||||||
"""Cache model provider credentials."""
|
"""Cache model provider credentials."""
|
||||||
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
|
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self):
|
||||||
"""
|
"""
|
||||||
Delete cached model provider credentials.
|
Delete cached model provider credentials.
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ def get_external_trace_id(request: Any) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
|
def extract_external_trace_id_from_args(args: Mapping[str, Any]):
|
||||||
"""
|
"""
|
||||||
Extract 'external_trace_id' from args.
|
Extract 'external_trace_id' from args.
|
||||||
|
|
||||||
|
|||||||
@ -44,11 +44,11 @@ class HostingConfiguration:
|
|||||||
provider_map: dict[str, HostingProvider]
|
provider_map: dict[str, HostingProvider]
|
||||||
moderation_config: Optional[HostedModerationConfig] = None
|
moderation_config: Optional[HostedModerationConfig] = None
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self):
|
||||||
self.provider_map = {}
|
self.provider_map = {}
|
||||||
self.moderation_config = None
|
self.moderation_config = None
|
||||||
|
|
||||||
def init_app(self, app: Flask) -> None:
|
def init_app(self, app: Flask):
|
||||||
if dify_config.EDITION != "CLOUD":
|
if dify_config.EDITION != "CLOUD":
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -270,7 +270,9 @@ class IndexingRunner:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
)
|
)
|
||||||
preview_texts = [] # type: ignore
|
# keep separate, avoid union-list ambiguity
|
||||||
|
preview_texts: list[PreviewDetail] = []
|
||||||
|
qa_preview_texts: list[QAPreviewDetail] = []
|
||||||
|
|
||||||
total_segments = 0
|
total_segments = 0
|
||||||
index_type = doc_form
|
index_type = doc_form
|
||||||
@ -293,14 +295,14 @@ class IndexingRunner:
|
|||||||
for document in documents:
|
for document in documents:
|
||||||
if len(preview_texts) < 10:
|
if len(preview_texts) < 10:
|
||||||
if doc_form and doc_form == "qa_model":
|
if doc_form and doc_form == "qa_model":
|
||||||
preview_detail = QAPreviewDetail(
|
qa_detail = QAPreviewDetail(
|
||||||
question=document.page_content, answer=document.metadata.get("answer") or ""
|
question=document.page_content, answer=document.metadata.get("answer") or ""
|
||||||
)
|
)
|
||||||
preview_texts.append(preview_detail)
|
qa_preview_texts.append(qa_detail)
|
||||||
else:
|
else:
|
||||||
preview_detail = PreviewDetail(content=document.page_content) # type: ignore
|
preview_detail = PreviewDetail(content=document.page_content)
|
||||||
if document.children:
|
if document.children:
|
||||||
preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore
|
preview_detail.child_chunks = [child.page_content for child in document.children]
|
||||||
preview_texts.append(preview_detail)
|
preview_texts.append(preview_detail)
|
||||||
|
|
||||||
# delete image files and related db records
|
# delete image files and related db records
|
||||||
@ -321,8 +323,8 @@ class IndexingRunner:
|
|||||||
db.session.delete(image_file)
|
db.session.delete(image_file)
|
||||||
|
|
||||||
if doc_form and doc_form == "qa_model":
|
if doc_form and doc_form == "qa_model":
|
||||||
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
|
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
|
||||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore
|
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
|
||||||
|
|
||||||
def _extract(
|
def _extract(
|
||||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
||||||
@ -425,6 +427,7 @@ class IndexingRunner:
|
|||||||
"""
|
"""
|
||||||
Get the NodeParser object according to the processing rule.
|
Get the NodeParser object according to the processing rule.
|
||||||
"""
|
"""
|
||||||
|
character_splitter: TextSplitter
|
||||||
if processing_rule_mode in ["custom", "hierarchical"]:
|
if processing_rule_mode in ["custom", "hierarchical"]:
|
||||||
# The user-defined segmentation rule
|
# The user-defined segmentation rule
|
||||||
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||||
@ -451,7 +454,7 @@ class IndexingRunner:
|
|||||||
embedding_model_instance=embedding_model_instance,
|
embedding_model_instance=embedding_model_instance,
|
||||||
)
|
)
|
||||||
|
|
||||||
return character_splitter # type: ignore
|
return character_splitter
|
||||||
|
|
||||||
def _split_to_documents_for_estimate(
|
def _split_to_documents_for_estimate(
|
||||||
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
|
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
|
||||||
@ -510,7 +513,7 @@ class IndexingRunner:
|
|||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
dataset_document: DatasetDocument,
|
dataset_document: DatasetDocument,
|
||||||
documents: list[Document],
|
documents: list[Document],
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
insert index and update document/segment status to completed
|
insert index and update document/segment status to completed
|
||||||
"""
|
"""
|
||||||
@ -649,7 +652,7 @@ class IndexingRunner:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_document_index_status(
|
def _update_document_index_status(
|
||||||
document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
|
document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Update the document indexing status.
|
Update the document indexing status.
|
||||||
"""
|
"""
|
||||||
@ -668,7 +671,7 @@ class IndexingRunner:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
|
def _update_segments_by_document(dataset_document_id: str, update_params: dict):
|
||||||
"""
|
"""
|
||||||
Update the document segment by document id.
|
Update the document segment by document id.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -129,7 +129,7 @@ class LLMGenerator:
|
|||||||
return questions
|
return questions
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict:
|
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool):
|
||||||
output_parser = RuleConfigGeneratorOutputParser()
|
output_parser = RuleConfigGeneratorOutputParser()
|
||||||
|
|
||||||
error = ""
|
error = ""
|
||||||
@ -264,9 +264,7 @@ class LLMGenerator:
|
|||||||
return rule_config
|
return rule_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_code(
|
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
|
||||||
cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"
|
|
||||||
) -> dict:
|
|
||||||
if code_language == "python":
|
if code_language == "python":
|
||||||
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||||
else:
|
else:
|
||||||
@ -375,7 +373,7 @@ class LLMGenerator:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def instruction_modify_legacy(
|
def instruction_modify_legacy(
|
||||||
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
|
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
|
||||||
) -> dict:
|
):
|
||||||
last_run: Message | None = (
|
last_run: Message | None = (
|
||||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||||
)
|
)
|
||||||
@ -415,7 +413,7 @@ class LLMGenerator:
|
|||||||
instruction: str,
|
instruction: str,
|
||||||
model_config: dict,
|
model_config: dict,
|
||||||
ideal_output: str | None,
|
ideal_output: str | None,
|
||||||
) -> dict:
|
):
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
session = db.session()
|
session = db.session()
|
||||||
@ -455,7 +453,7 @@ class LLMGenerator:
|
|||||||
return []
|
return []
|
||||||
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
|
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
|
||||||
|
|
||||||
def dict_of_event(event: AgentLogEvent) -> dict:
|
def dict_of_event(event: AgentLogEvent):
|
||||||
return {
|
return {
|
||||||
"status": event.status,
|
"status": event.status,
|
||||||
"error": event.error,
|
"error": event.error,
|
||||||
@ -493,7 +491,7 @@ class LLMGenerator:
|
|||||||
instruction: str,
|
instruction: str,
|
||||||
node_type: str,
|
node_type: str,
|
||||||
ideal_output: str | None,
|
ideal_output: str | None,
|
||||||
) -> dict:
|
):
|
||||||
LAST_RUN = "{{#last_run#}}"
|
LAST_RUN = "{{#last_run#}}"
|
||||||
CURRENT = "{{#current#}}"
|
CURRENT = "{{#current#}}"
|
||||||
ERROR_MESSAGE = "{{#error_message#}}"
|
ERROR_MESSAGE = "{{#error_message#}}"
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.llm_generator.output_parser.errors import OutputParserError
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
from core.llm_generator.prompts import (
|
from core.llm_generator.prompts import (
|
||||||
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
||||||
@ -17,7 +15,7 @@ class RuleConfigGeneratorOutputParser:
|
|||||||
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
|
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse(self, text: str) -> Any:
|
def parse(self, text: str):
|
||||||
try:
|
try:
|
||||||
expected_keys = ["prompt", "variables", "opening_statement"]
|
expected_keys = ["prompt", "variables", "opening_statement"]
|
||||||
parsed = parse_and_check_json_markdown(text, expected_keys)
|
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||||
|
|||||||
@ -210,7 +210,7 @@ def _handle_native_json_schema(
|
|||||||
structured_output_schema: Mapping,
|
structured_output_schema: Mapping,
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
rules: list[ParameterRule],
|
rules: list[ParameterRule],
|
||||||
) -> dict:
|
):
|
||||||
"""
|
"""
|
||||||
Handle structured output for models with native JSON schema support.
|
Handle structured output for models with native JSON schema support.
|
||||||
|
|
||||||
@ -232,7 +232,7 @@ def _handle_native_json_schema(
|
|||||||
return model_parameters
|
return model_parameters
|
||||||
|
|
||||||
|
|
||||||
def _set_response_format(model_parameters: dict, rules: list) -> None:
|
def _set_response_format(model_parameters: dict, rules: list):
|
||||||
"""
|
"""
|
||||||
Set the appropriate response format parameter based on model rules.
|
Set the appropriate response format parameter based on model rules.
|
||||||
|
|
||||||
@ -306,7 +306,7 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
|||||||
return structured_output
|
return structured_output
|
||||||
|
|
||||||
|
|
||||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
|
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
|
||||||
"""
|
"""
|
||||||
Prepare JSON schema based on model requirements.
|
Prepare JSON schema based on model requirements.
|
||||||
|
|
||||||
@ -334,7 +334,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
|
|||||||
return {"schema": processed_schema, "name": "llm_response"}
|
return {"schema": processed_schema, "name": "llm_response"}
|
||||||
|
|
||||||
|
|
||||||
def remove_additional_properties(schema: dict) -> None:
|
def remove_additional_properties(schema: dict):
|
||||||
"""
|
"""
|
||||||
Remove additionalProperties fields from JSON schema.
|
Remove additionalProperties fields from JSON schema.
|
||||||
Used for models like Gemini that don't support this property.
|
Used for models like Gemini that don't support this property.
|
||||||
@ -357,7 +357,7 @@ def remove_additional_properties(schema: dict) -> None:
|
|||||||
remove_additional_properties(item)
|
remove_additional_properties(item)
|
||||||
|
|
||||||
|
|
||||||
def convert_boolean_to_string(schema: dict) -> None:
|
def convert_boolean_to_string(schema: dict):
|
||||||
"""
|
"""
|
||||||
Convert boolean type specifications to string in JSON schema.
|
Convert boolean type specifications to string in JSON schema.
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||||
|
|
||||||
@ -9,7 +8,7 @@ class SuggestedQuestionsAfterAnswerOutputParser:
|
|||||||
def get_format_instructions(self) -> str:
|
def get_format_instructions(self) -> str:
|
||||||
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||||
|
|
||||||
def parse(self, text: str) -> Any:
|
def parse(self, text: str):
|
||||||
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
|
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
|
||||||
if action_match is not None:
|
if action_match is not None:
|
||||||
json_obj = json.loads(action_match.group(0).strip())
|
json_obj = json.loads(action_match.group(0).strip())
|
||||||
|
|||||||
@ -44,7 +44,7 @@ class OAuthClientProvider:
|
|||||||
return None
|
return None
|
||||||
return OAuthClientInformation.model_validate(client_information)
|
return OAuthClientInformation.model_validate(client_information)
|
||||||
|
|
||||||
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
|
def save_client_information(self, client_information: OAuthClientInformationFull):
|
||||||
"""Saves client information after dynamic registration."""
|
"""Saves client information after dynamic registration."""
|
||||||
MCPToolManageService.update_mcp_provider_credentials(
|
MCPToolManageService.update_mcp_provider_credentials(
|
||||||
self.mcp_provider,
|
self.mcp_provider,
|
||||||
@ -63,13 +63,13 @@ class OAuthClientProvider:
|
|||||||
refresh_token=credentials.get("refresh_token", ""),
|
refresh_token=credentials.get("refresh_token", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_tokens(self, tokens: OAuthTokens) -> None:
|
def save_tokens(self, tokens: OAuthTokens):
|
||||||
"""Stores new OAuth tokens for the current session."""
|
"""Stores new OAuth tokens for the current session."""
|
||||||
# update mcp provider credentials
|
# update mcp provider credentials
|
||||||
token_dict = tokens.model_dump()
|
token_dict = tokens.model_dump()
|
||||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||||
|
|
||||||
def save_code_verifier(self, code_verifier: str) -> None:
|
def save_code_verifier(self, code_verifier: str):
|
||||||
"""Saves a PKCE code verifier for the current session."""
|
"""Saves a PKCE code verifier for the current session."""
|
||||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||||
|
|
||||||
|
|||||||
@ -47,7 +47,7 @@ class SSETransport:
|
|||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: float = 5.0,
|
timeout: float = 5.0,
|
||||||
sse_read_timeout: float = 5 * 60,
|
sse_read_timeout: float = 5 * 60,
|
||||||
) -> None:
|
):
|
||||||
"""Initialize the SSE transport.
|
"""Initialize the SSE transport.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -76,7 +76,7 @@ class SSETransport:
|
|||||||
|
|
||||||
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
|
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
|
||||||
|
|
||||||
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
|
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue):
|
||||||
"""Handle an 'endpoint' SSE event.
|
"""Handle an 'endpoint' SSE event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -94,7 +94,7 @@ class SSETransport:
|
|||||||
|
|
||||||
status_queue.put(_StatusReady(endpoint_url))
|
status_queue.put(_StatusReady(endpoint_url))
|
||||||
|
|
||||||
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
|
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue):
|
||||||
"""Handle a 'message' SSE event.
|
"""Handle a 'message' SSE event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -110,7 +110,7 @@ class SSETransport:
|
|||||||
logger.exception("Error parsing server message")
|
logger.exception("Error parsing server message")
|
||||||
read_queue.put(exc)
|
read_queue.put(exc)
|
||||||
|
|
||||||
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue):
|
||||||
"""Handle a single SSE event.
|
"""Handle a single SSE event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -126,7 +126,7 @@ class SSETransport:
|
|||||||
case _:
|
case _:
|
||||||
logger.warning("Unknown SSE event: %s", sse.event)
|
logger.warning("Unknown SSE event: %s", sse.event)
|
||||||
|
|
||||||
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue):
|
||||||
"""Read and process SSE events.
|
"""Read and process SSE events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -144,7 +144,7 @@ class SSETransport:
|
|||||||
finally:
|
finally:
|
||||||
read_queue.put(None)
|
read_queue.put(None)
|
||||||
|
|
||||||
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
|
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage):
|
||||||
"""Send a single message to the server.
|
"""Send a single message to the server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -163,7 +163,7 @@ class SSETransport:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
logger.debug("Client message sent successfully: %s", response.status_code)
|
logger.debug("Client message sent successfully: %s", response.status_code)
|
||||||
|
|
||||||
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
|
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue):
|
||||||
"""Handle writing messages to the server.
|
"""Handle writing messages to the server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -303,7 +303,7 @@ def sse_client(
|
|||||||
write_queue.put(None)
|
write_queue.put(None)
|
||||||
|
|
||||||
|
|
||||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
|
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
|
||||||
"""
|
"""
|
||||||
Send a message to the server using the provided HTTP client.
|
Send a message to the server using the provided HTTP client.
|
||||||
|
|
||||||
|
|||||||
@ -82,7 +82,7 @@ class StreamableHTTPTransport:
|
|||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: float | timedelta = 30,
|
timeout: float | timedelta = 30,
|
||||||
sse_read_timeout: float | timedelta = 60 * 5,
|
sse_read_timeout: float | timedelta = 60 * 5,
|
||||||
) -> None:
|
):
|
||||||
"""Initialize the StreamableHTTP transport.
|
"""Initialize the StreamableHTTP transport.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -122,7 +122,7 @@ class StreamableHTTPTransport:
|
|||||||
def _maybe_extract_session_id_from_response(
|
def _maybe_extract_session_id_from_response(
|
||||||
self,
|
self,
|
||||||
response: httpx.Response,
|
response: httpx.Response,
|
||||||
) -> None:
|
):
|
||||||
"""Extract and store session ID from response headers."""
|
"""Extract and store session ID from response headers."""
|
||||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||||
if new_session_id:
|
if new_session_id:
|
||||||
@ -173,7 +173,7 @@ class StreamableHTTPTransport:
|
|||||||
self,
|
self,
|
||||||
client: httpx.Client,
|
client: httpx.Client,
|
||||||
server_to_client_queue: ServerToClientQueue,
|
server_to_client_queue: ServerToClientQueue,
|
||||||
) -> None:
|
):
|
||||||
"""Handle GET stream for server-initiated messages."""
|
"""Handle GET stream for server-initiated messages."""
|
||||||
try:
|
try:
|
||||||
if not self.session_id:
|
if not self.session_id:
|
||||||
@ -197,7 +197,7 @@ class StreamableHTTPTransport:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("GET stream error (non-fatal): %s", exc)
|
logger.debug("GET stream error (non-fatal): %s", exc)
|
||||||
|
|
||||||
def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
def _handle_resumption_request(self, ctx: RequestContext):
|
||||||
"""Handle a resumption request using GET with SSE."""
|
"""Handle a resumption request using GET with SSE."""
|
||||||
headers = self._update_headers_with_session(ctx.headers)
|
headers = self._update_headers_with_session(ctx.headers)
|
||||||
if ctx.metadata and ctx.metadata.resumption_token:
|
if ctx.metadata and ctx.metadata.resumption_token:
|
||||||
@ -230,7 +230,7 @@ class StreamableHTTPTransport:
|
|||||||
if is_complete:
|
if is_complete:
|
||||||
break
|
break
|
||||||
|
|
||||||
def _handle_post_request(self, ctx: RequestContext) -> None:
|
def _handle_post_request(self, ctx: RequestContext):
|
||||||
"""Handle a POST request with response processing."""
|
"""Handle a POST request with response processing."""
|
||||||
headers = self._update_headers_with_session(ctx.headers)
|
headers = self._update_headers_with_session(ctx.headers)
|
||||||
message = ctx.session_message.message
|
message = ctx.session_message.message
|
||||||
@ -278,7 +278,7 @@ class StreamableHTTPTransport:
|
|||||||
self,
|
self,
|
||||||
response: httpx.Response,
|
response: httpx.Response,
|
||||||
server_to_client_queue: ServerToClientQueue,
|
server_to_client_queue: ServerToClientQueue,
|
||||||
) -> None:
|
):
|
||||||
"""Handle JSON response from the server."""
|
"""Handle JSON response from the server."""
|
||||||
try:
|
try:
|
||||||
content = response.read()
|
content = response.read()
|
||||||
@ -288,7 +288,7 @@ class StreamableHTTPTransport:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
server_to_client_queue.put(exc)
|
server_to_client_queue.put(exc)
|
||||||
|
|
||||||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
|
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
|
||||||
"""Handle SSE response from the server."""
|
"""Handle SSE response from the server."""
|
||||||
try:
|
try:
|
||||||
event_source = EventSource(response)
|
event_source = EventSource(response)
|
||||||
@ -307,7 +307,7 @@ class StreamableHTTPTransport:
|
|||||||
self,
|
self,
|
||||||
content_type: str,
|
content_type: str,
|
||||||
server_to_client_queue: ServerToClientQueue,
|
server_to_client_queue: ServerToClientQueue,
|
||||||
) -> None:
|
):
|
||||||
"""Handle unexpected content type in response."""
|
"""Handle unexpected content type in response."""
|
||||||
error_msg = f"Unexpected content type: {content_type}"
|
error_msg = f"Unexpected content type: {content_type}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@ -317,7 +317,7 @@ class StreamableHTTPTransport:
|
|||||||
self,
|
self,
|
||||||
server_to_client_queue: ServerToClientQueue,
|
server_to_client_queue: ServerToClientQueue,
|
||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
) -> None:
|
):
|
||||||
"""Send a session terminated error response."""
|
"""Send a session terminated error response."""
|
||||||
jsonrpc_error = JSONRPCError(
|
jsonrpc_error = JSONRPCError(
|
||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
@ -333,7 +333,7 @@ class StreamableHTTPTransport:
|
|||||||
client_to_server_queue: ClientToServerQueue,
|
client_to_server_queue: ClientToServerQueue,
|
||||||
server_to_client_queue: ServerToClientQueue,
|
server_to_client_queue: ServerToClientQueue,
|
||||||
start_get_stream: Callable[[], None],
|
start_get_stream: Callable[[], None],
|
||||||
) -> None:
|
):
|
||||||
"""Handle writing requests to the server.
|
"""Handle writing requests to the server.
|
||||||
|
|
||||||
This method processes messages from the client_to_server_queue and sends them to the server.
|
This method processes messages from the client_to_server_queue and sends them to the server.
|
||||||
@ -379,7 +379,7 @@ class StreamableHTTPTransport:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
server_to_client_queue.put(exc)
|
server_to_client_queue.put(exc)
|
||||||
|
|
||||||
def terminate_session(self, client: httpx.Client) -> None:
|
def terminate_session(self, client: httpx.Client):
|
||||||
"""Terminate the session by sending a DELETE request."""
|
"""Terminate the session by sending a DELETE request."""
|
||||||
if not self.session_id:
|
if not self.session_id:
|
||||||
return
|
return
|
||||||
@ -441,7 +441,7 @@ def streamablehttp_client(
|
|||||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||||
) as client:
|
) as client:
|
||||||
# Define callbacks that need access to thread pool
|
# Define callbacks that need access to thread pool
|
||||||
def start_get_stream() -> None:
|
def start_get_stream():
|
||||||
"""Start a worker thread to handle server-initiated messages."""
|
"""Start a worker thread to handle server-initiated messages."""
|
||||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||||
|
|
||||||
|
|||||||
@ -76,7 +76,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
ReceiveNotificationT
|
ReceiveNotificationT
|
||||||
]""",
|
]""",
|
||||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||||
) -> None:
|
):
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.request_meta = request_meta
|
self.request_meta = request_meta
|
||||||
self.request = request
|
self.request = request
|
||||||
@ -95,7 +95,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
exc_type: type[BaseException] | None,
|
exc_type: type[BaseException] | None,
|
||||||
exc_val: BaseException | None,
|
exc_val: BaseException | None,
|
||||||
exc_tb: TracebackType | None,
|
exc_tb: TracebackType | None,
|
||||||
) -> None:
|
):
|
||||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||||
try:
|
try:
|
||||||
if self._completed:
|
if self._completed:
|
||||||
@ -103,7 +103,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
finally:
|
finally:
|
||||||
self._entered = False
|
self._entered = False
|
||||||
|
|
||||||
def respond(self, response: SendResultT | ErrorData) -> None:
|
def respond(self, response: SendResultT | ErrorData):
|
||||||
"""Send a response for this request.
|
"""Send a response for this request.
|
||||||
|
|
||||||
Must be called within a context manager block.
|
Must be called within a context manager block.
|
||||||
@ -119,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
|||||||
|
|
||||||
self._session._send_response(request_id=self.request_id, response=response)
|
self._session._send_response(request_id=self.request_id, response=response)
|
||||||
|
|
||||||
def cancel(self) -> None:
|
def cancel(self):
|
||||||
"""Cancel this request and mark it as completed."""
|
"""Cancel this request and mark it as completed."""
|
||||||
if not self._entered:
|
if not self._entered:
|
||||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||||
@ -163,7 +163,7 @@ class BaseSession(
|
|||||||
receive_notification_type: type[ReceiveNotificationT],
|
receive_notification_type: type[ReceiveNotificationT],
|
||||||
# If none, reading will never time out
|
# If none, reading will never time out
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
) -> None:
|
):
|
||||||
self._read_stream = read_stream
|
self._read_stream = read_stream
|
||||||
self._write_stream = write_stream
|
self._write_stream = write_stream
|
||||||
self._response_streams = {}
|
self._response_streams = {}
|
||||||
@ -183,7 +183,7 @@ class BaseSession(
|
|||||||
self._receiver_future = self._executor.submit(self._receive_loop)
|
self._receiver_future = self._executor.submit(self._receive_loop)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def check_receiver_status(self) -> None:
|
def check_receiver_status(self):
|
||||||
"""`check_receiver_status` ensures that any exceptions raised during the
|
"""`check_receiver_status` ensures that any exceptions raised during the
|
||||||
execution of `_receive_loop` are retrieved and propagated."""
|
execution of `_receive_loop` are retrieved and propagated."""
|
||||||
if self._receiver_future and self._receiver_future.done():
|
if self._receiver_future and self._receiver_future.done():
|
||||||
@ -191,7 +191,7 @@ class BaseSession(
|
|||||||
|
|
||||||
def __exit__(
|
def __exit__(
|
||||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
||||||
) -> None:
|
):
|
||||||
self._read_stream.put(None)
|
self._read_stream.put(None)
|
||||||
self._write_stream.put(None)
|
self._write_stream.put(None)
|
||||||
|
|
||||||
@ -277,7 +277,7 @@ class BaseSession(
|
|||||||
self,
|
self,
|
||||||
notification: SendNotificationT,
|
notification: SendNotificationT,
|
||||||
related_request_id: RequestId | None = None,
|
related_request_id: RequestId | None = None,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Emits a notification, which is a one-way message that does not expect
|
Emits a notification, which is a one-way message that does not expect
|
||||||
a response.
|
a response.
|
||||||
@ -296,7 +296,7 @@ class BaseSession(
|
|||||||
)
|
)
|
||||||
self._write_stream.put(session_message)
|
self._write_stream.put(session_message)
|
||||||
|
|
||||||
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
|
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
|
||||||
if isinstance(response, ErrorData):
|
if isinstance(response, ErrorData):
|
||||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||||
@ -310,7 +310,7 @@ class BaseSession(
|
|||||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||||
self._write_stream.put(session_message)
|
self._write_stream.put(session_message)
|
||||||
|
|
||||||
def _receive_loop(self) -> None:
|
def _receive_loop(self):
|
||||||
"""
|
"""
|
||||||
Main message processing loop.
|
Main message processing loop.
|
||||||
In a real synchronous implementation, this would likely run in a separate thread.
|
In a real synchronous implementation, this would likely run in a separate thread.
|
||||||
@ -382,7 +382,7 @@ class BaseSession(
|
|||||||
logger.exception("Error in message processing loop")
|
logger.exception("Error in message processing loop")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]):
|
||||||
"""
|
"""
|
||||||
Can be overridden by subclasses to handle a request without needing to
|
Can be overridden by subclasses to handle a request without needing to
|
||||||
listen on the message stream.
|
listen on the message stream.
|
||||||
@ -391,15 +391,13 @@ class BaseSession(
|
|||||||
forwarded on to the message stream.
|
forwarded on to the message stream.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
def _received_notification(self, notification: ReceiveNotificationT):
|
||||||
"""
|
"""
|
||||||
Can be overridden by subclasses to handle a notification without needing
|
Can be overridden by subclasses to handle a notification without needing
|
||||||
to listen on the message stream.
|
to listen on the message stream.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def send_progress_notification(
|
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||||
self, progress_token: str | int, progress: float, total: float | None = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Sends a progress notification for a request that is currently being
|
Sends a progress notification for a request that is currently being
|
||||||
processed.
|
processed.
|
||||||
@ -408,5 +406,5 @@ class BaseSession(
|
|||||||
def _handle_incoming(
|
def _handle_incoming(
|
||||||
self,
|
self,
|
||||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||||
) -> None:
|
):
|
||||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||||
|
|||||||
@ -28,19 +28,19 @@ class LoggingFnT(Protocol):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
params: types.LoggingMessageNotificationParams,
|
params: types.LoggingMessageNotificationParams,
|
||||||
) -> None: ...
|
): ...
|
||||||
|
|
||||||
|
|
||||||
class MessageHandlerFnT(Protocol):
|
class MessageHandlerFnT(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||||
) -> None: ...
|
): ...
|
||||||
|
|
||||||
|
|
||||||
def _default_message_handler(
|
def _default_message_handler(
|
||||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||||
) -> None:
|
):
|
||||||
if isinstance(message, Exception):
|
if isinstance(message, Exception):
|
||||||
raise ValueError(str(message))
|
raise ValueError(str(message))
|
||||||
elif isinstance(message, (types.ServerNotification | RequestResponder)):
|
elif isinstance(message, (types.ServerNotification | RequestResponder)):
|
||||||
@ -68,7 +68,7 @@ def _default_list_roots_callback(
|
|||||||
|
|
||||||
def _default_logging_callback(
|
def _default_logging_callback(
|
||||||
params: types.LoggingMessageNotificationParams,
|
params: types.LoggingMessageNotificationParams,
|
||||||
) -> None:
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ class ClientSession(
|
|||||||
logging_callback: LoggingFnT | None = None,
|
logging_callback: LoggingFnT | None = None,
|
||||||
message_handler: MessageHandlerFnT | None = None,
|
message_handler: MessageHandlerFnT | None = None,
|
||||||
client_info: types.Implementation | None = None,
|
client_info: types.Implementation | None = None,
|
||||||
) -> None:
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
read_stream,
|
read_stream,
|
||||||
write_stream,
|
write_stream,
|
||||||
@ -155,9 +155,7 @@ class ClientSession(
|
|||||||
types.EmptyResult,
|
types.EmptyResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_progress_notification(
|
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||||
self, progress_token: str | int, progress: float, total: float | None = None
|
|
||||||
) -> None:
|
|
||||||
"""Send a progress notification."""
|
"""Send a progress notification."""
|
||||||
self.send_notification(
|
self.send_notification(
|
||||||
types.ClientNotification(
|
types.ClientNotification(
|
||||||
@ -314,7 +312,7 @@ class ClientSession(
|
|||||||
types.ListToolsResult,
|
types.ListToolsResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_roots_list_changed(self) -> None:
|
def send_roots_list_changed(self):
|
||||||
"""Send a roots/list_changed notification."""
|
"""Send a roots/list_changed notification."""
|
||||||
self.send_notification(
|
self.send_notification(
|
||||||
types.ClientNotification(
|
types.ClientNotification(
|
||||||
@ -324,7 +322,7 @@ class ClientSession(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
|
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
|
||||||
ctx = RequestContext[ClientSession, Any](
|
ctx = RequestContext[ClientSession, Any](
|
||||||
request_id=responder.request_id,
|
request_id=responder.request_id,
|
||||||
meta=responder.request_meta,
|
meta=responder.request_meta,
|
||||||
@ -352,11 +350,11 @@ class ClientSession(
|
|||||||
def _handle_incoming(
|
def _handle_incoming(
|
||||||
self,
|
self,
|
||||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||||
) -> None:
|
):
|
||||||
"""Handle incoming messages by forwarding to the message handler."""
|
"""Handle incoming messages by forwarding to the message handler."""
|
||||||
self._message_handler(req)
|
self._message_handler(req)
|
||||||
|
|
||||||
def _received_notification(self, notification: types.ServerNotification) -> None:
|
def _received_notification(self, notification: types.ServerNotification):
|
||||||
"""Handle notifications from the server."""
|
"""Handle notifications from the server."""
|
||||||
# Process specific notification types
|
# Process specific notification types
|
||||||
match notification.root:
|
match notification.root:
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class TokenBufferMemory:
|
|||||||
self,
|
self,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
) -> None:
|
):
|
||||||
self.conversation = conversation
|
self.conversation = conversation
|
||||||
self.model_instance = model_instance
|
self.model_instance = model_instance
|
||||||
|
|
||||||
@ -124,6 +124,7 @@ class TokenBufferMemory:
|
|||||||
|
|
||||||
messages = list(reversed(thread_messages))
|
messages = list(reversed(thread_messages))
|
||||||
|
|
||||||
|
curr_message_tokens = 0
|
||||||
prompt_messages: list[PromptMessage] = []
|
prompt_messages: list[PromptMessage] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
# Process user message with files
|
# Process user message with files
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class ModelInstance:
|
|||||||
Model instance class
|
Model instance class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
|
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||||
self.provider_model_bundle = provider_model_bundle
|
self.provider_model_bundle = provider_model_bundle
|
||||||
self.model = model
|
self.model = model
|
||||||
self.provider = provider_model_bundle.configuration.provider.provider
|
self.provider = provider_model_bundle.configuration.provider.provider
|
||||||
@ -46,7 +46,7 @@ class ModelInstance:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
|
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str):
|
||||||
"""
|
"""
|
||||||
Fetch credentials from provider model bundle
|
Fetch credentials from provider model bundle
|
||||||
:param provider_model_bundle: provider model bundle
|
:param provider_model_bundle: provider model bundle
|
||||||
@ -342,7 +342,7 @@ class ModelInstance:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
|
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Round-robin invoke
|
Round-robin invoke
|
||||||
:param function: function to invoke
|
:param function: function to invoke
|
||||||
@ -379,7 +379,7 @@ class ModelInstance:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_tts_voices(self, language: Optional[str] = None) -> list:
|
def get_tts_voices(self, language: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Invoke large language tts model voices
|
Invoke large language tts model voices
|
||||||
|
|
||||||
@ -394,7 +394,7 @@ class ModelInstance:
|
|||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self):
|
||||||
self._provider_manager = ProviderManager()
|
self._provider_manager = ProviderManager()
|
||||||
|
|
||||||
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
|
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
|
||||||
@ -453,7 +453,7 @@ class LBModelManager:
|
|||||||
model: str,
|
model: str,
|
||||||
load_balancing_configs: list[ModelLoadBalancingConfiguration],
|
load_balancing_configs: list[ModelLoadBalancingConfiguration],
|
||||||
managed_credentials: Optional[dict] = None,
|
managed_credentials: Optional[dict] = None,
|
||||||
) -> None:
|
):
|
||||||
"""
|
"""
|
||||||
Load balancing model manager
|
Load balancing model manager
|
||||||
:param tenant_id: tenant_id
|
:param tenant_id: tenant_id
|
||||||
@ -534,7 +534,7 @@ model: %s""",
|
|||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
|
def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60):
|
||||||
"""
|
"""
|
||||||
Cooldown model load balancing config
|
Cooldown model load balancing config
|
||||||
:param config: model load balancing config
|
:param config: model load balancing config
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user