merge main

This commit is contained in:
JzoNg 2024-12-16 11:09:13 +08:00
commit 171cc88a0d
411 changed files with 11923 additions and 6592 deletions

View File

@ -8,5 +8,6 @@ echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
source /home/vscode/.bashrc

View File

@ -37,6 +37,7 @@ jobs:
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: |
poetry run -C api ruff --version
poetry run -C api ruff check ./api
poetry run -C api ruff format --check ./api

View File

@ -51,7 +51,7 @@ jobs:
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
- name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
@ -67,6 +67,7 @@ jobs:
pgvector
chroma
elasticsearch
tidb
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh

View File

@ -56,20 +56,36 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=opendal
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
STORAGE_OPENDAL_SCHEME=fs
# OpenDAL FS
OPENDAL_FS_ROOT=storage
# OpenDAL S3
OPENDAL_S3_ROOT=/
OPENDAL_S3_BUCKET=your-bucket-name
OPENDAL_S3_ENDPOINT=https://s3.amazonaws.com
OPENDAL_S3_ACCESS_KEY_ID=your-access-key
OPENDAL_S3_SECRET_ACCESS_KEY=your-secret-key
OPENDAL_S3_REGION=your-region
OPENDAL_S3_SERVER_SIDE_ENCRYPTION=
# S3 Storage configuration
S3_USE_AWS_MANAGED_IAM=false
S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
AZURE_BLOB_CONTAINER_NAME=yout-container-name
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# Aliyun oss Storage configuration
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
ALIYUN_OSS_ACCESS_KEY=your-access-key
@ -79,6 +95,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@ -125,8 +142,8 @@ SUPABASE_URL=your-server-url
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
VECTOR_STORE=weaviate
# Weaviate configuration
@ -277,6 +294,7 @@ VIKINGDB_SOCKET_TIMEOUT=30
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
USING_UGC_INDEX=False
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
@ -381,6 +399,8 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000

View File

@ -1,11 +1,51 @@
from pydantic_settings import SettingsConfigDict
import logging
from typing import Any
from configs.deploy import DeploymentConfig
from configs.enterprise import EnterpriseFeatureConfig
from configs.extra import ExtraServiceConfig
from configs.feature import FeatureConfig
from configs.middleware import MiddlewareConfig
from configs.packaging import PackagingInfo
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource
logger = logging.getLogger(__name__)
class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def __init__(self, settings_cls: type[BaseSettings]):
super().__init__(settings_cls)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def __call__(self) -> dict[str, Any]:
current_state = self.current_state
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
if not remote_source_name:
return {}
remote_source: RemoteSettingsSource | None = None
match remote_source_name:
case RemoteSettingsSourceName.APOLLO:
remote_source = ApolloSettingsSource(current_state)
case _:
logger.warning(f"Unsupported remote source: {remote_source_name}")
return {}
d: dict[str, Any] = {}
for field_name, field in self.settings_cls.model_fields.items():
field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name)
field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex)
if field_value is not None:
d[field_key] = field_value
return d
class DifyConfig(
@ -19,6 +59,8 @@ class DifyConfig(
MiddlewareConfig,
# Extra service configs
ExtraServiceConfig,
# Remote source configs
RemoteSettingsSourceConfig,
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
@ -35,3 +77,20 @@ class DifyConfig(
# please consider to arrange it in the proper config group of existed or added
# for better readability and maintainability.
# Thanks for your concentration and consideration.
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
RemoteSettingsSourceFactory(settings_cls),
dotenv_settings,
file_secret_settings,
)

View File

@ -1,54 +1,69 @@
from typing import Any, Optional
from typing import Any, Literal, Optional
from urllib.parse import quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.lindorm_config import LindormConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
from configs.middleware.vdb.qdrant_config import QdrantConfig
from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.upstash_config import UpstashConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig
from .cache.redis_config import RedisConfig
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from .storage.oci_storage_config import OCIStorageConfig
from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig
from .vdb.pgvectors_config import PGVectoRSConfig
from .vdb.qdrant_config import QdrantConfig
from .vdb.relyt_config import RelytConfig
from .vdb.tencent_vector_config import TencentVectorDBConfig
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from .vdb.tidb_vector_config import TiDBVectorConfig
from .vdb.upstash_config import UpstashConfig
from .vdb.vikingdb_config import VikingDBConfig
from .vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
STORAGE_TYPE: Literal[
"opendal",
"s3",
"aliyun-oss",
"azure-blob",
"baidu-obs",
"google-storage",
"huawei-obs",
"oci-storage",
"tencent-cos",
"volcengine-tos",
"supabase",
"local",
] = Field(
description="Type of storage to use."
" Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
"'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
default="local",
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal",
)
STORAGE_LOCAL_PATH: str = Field(
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
default="storage",
deprecated=True,
)
@ -73,7 +88,7 @@ class KeywordStoreConfig(BaseSettings):
)
class DatabaseConfig:
class DatabaseConfig(BaseSettings):
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
default="localhost",
@ -235,6 +250,7 @@ class MiddlewareConfig(
GoogleCloudStorageConfig,
HuaweiCloudOBSStorageConfig,
OCIStorageConfig,
OpenDALStorageConfig,
S3StorageConfig,
SupabaseStorageConfig,
TencentCloudCOSStorageConfig,

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class BaiduOBSStorageConfig(BaseModel):
class BaiduOBSStorageConfig(BaseSettings):
"""
Configuration settings for Baidu Object Storage Service (OBS)
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class HuaweiCloudOBSStorageConfig(BaseModel):
class HuaweiCloudOBSStorageConfig(BaseSettings):
"""
Configuration settings for Huawei Cloud Object Storage Service (OBS)
"""

View File

@ -0,0 +1,51 @@
from enum import StrEnum
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings
class OpenDALScheme(StrEnum):
FS = "fs"
S3 = "s3"
class OpenDALStorageConfig(BaseSettings):
STORAGE_OPENDAL_SCHEME: str = Field(
default=OpenDALScheme.FS.value,
description="OpenDAL scheme.",
)
# FS
OPENDAL_FS_ROOT: str = Field(
default="storage",
description="Root path for local storage.",
)
# S3
OPENDAL_S3_ROOT: str = Field(
default="/",
description="Root path for S3 storage.",
)
OPENDAL_S3_BUCKET: str = Field(
default="",
description="S3 bucket name.",
)
OPENDAL_S3_ENDPOINT: str = Field(
default="https://s3.amazonaws.com",
description="S3 endpoint URL.",
)
OPENDAL_S3_ACCESS_KEY_ID: str = Field(
default="",
description="S3 access key ID.",
)
OPENDAL_S3_SECRET_ACCESS_KEY: str = Field(
default="",
description="S3 secret access key.",
)
OPENDAL_S3_REGION: str = Field(
default="",
description="S3 region.",
)
OPENDAL_S3_SERVER_SIDE_ENCRYPTION: Literal["aws:kms", ""] = Field(
default="",
description="S3 server-side encryption.",
)

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class SupabaseStorageConfig(BaseModel):
class SupabaseStorageConfig(BaseSettings):
"""
Configuration settings for Supabase Object Storage Service
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseModel):
class VolcengineTOSStorageConfig(BaseSettings):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AnalyticdbConfig(BaseModel):
class AnalyticdbConfig(BaseSettings):
"""
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
Refer to the following documentation for details on obtaining credentials:

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class CouchbaseConfig(BaseModel):
class CouchbaseConfig(BaseSettings):
"""
Couchbase configs
"""

View File

@ -21,3 +21,14 @@ class LindormConfig(BaseSettings):
description="Lindorm password",
default=None,
)
DEFAULT_INDEX_TYPE: Optional[str] = Field(
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
default="hnsw",
)
DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
)
USING_UGC_INDEX: Optional[bool] = Field(
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
default=False,
)

View File

@ -1,7 +1,8 @@
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class MyScaleConfig(BaseModel):
class MyScaleConfig(BaseSettings):
"""
Configuration settings for MyScale vector database
"""

View File

@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class VikingDBConfig(BaseModel):
class VikingDBConfig(BaseSettings):
"""
Configuration for connecting to Volcengine VikingDB.
Refer to the following documentation for details on obtaining credentials:

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.13.1",
default="0.13.2",
)
COMMIT_SHA: str = Field(

View File

@ -0,0 +1,17 @@
from typing import Optional
from pydantic import Field
from .apollo import ApolloSettingsSourceInfo
from .base import RemoteSettingsSource
from .enums import RemoteSettingsSourceName
class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field(
description="name of remote config source",
default="",
)
__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]

View File

@ -0,0 +1,55 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import Field
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings
from configs.remote_settings_sources.base import RemoteSettingsSource
from .client import ApolloClient
class ApolloSettingsSourceInfo(BaseSettings):
"""
Packaging build information
"""
APOLLO_APP_ID: Optional[str] = Field(
description="apollo app_id",
default=None,
)
APOLLO_CLUSTER: Optional[str] = Field(
description="apollo cluster",
default=None,
)
APOLLO_CONFIG_URL: Optional[str] = Field(
description="apollo config url",
default=None,
)
APOLLO_NAMESPACE: Optional[str] = Field(
description="apollo namespace",
default=None,
)
class ApolloSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]):
self.client = ApolloClient(
app_id=configs["APOLLO_APP_ID"],
cluster=configs["APOLLO_CLUSTER"],
config_url=configs["APOLLO_CONFIG_URL"],
start_hot_update=False,
_notification_map={configs["APOLLO_NAMESPACE"]: -1},
)
self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name)
return field_value, field_name, False

View File

@ -0,0 +1,303 @@
import hashlib
import json
import logging
import os
import threading
import time
from pathlib import Path
from .python_3x import http_request, makedirs_wrapper
from .utils import (
CONFIGURATIONS,
NAMESPACE_NAME,
NOTIFICATION_ID,
get_value_from_dict,
init_ip,
no_key_cache_key,
signature,
url_encode_wrapper,
)
logger = logging.getLogger(__name__)
class ApolloClient:
def __init__(
self,
config_url,
app_id,
cluster="default",
secret="",
start_hot_update=True,
change_listener=None,
_notification_map=None,
):
# Core routing parameters
self.config_url = config_url
self.cluster = cluster
self.app_id = app_id
# Non-core parameters
self.ip = init_ip()
self.secret = secret
# Check the parameter variables
# Private control variables
self._cycle_time = 5
self._stopping = False
self._cache = {}
self._no_key = {}
self._hash = {}
self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread = None
self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None:
_notification_map = {"application": -1}
self._notification_map = _notification_map
self.last_release_key = None
# Private startup method
self._path_checker()
if start_hot_update:
self._start_hot_update()
# start the heartbeat thread
heartbeat = threading.Thread(target=self._heart_beat)
heartbeat.daemon = True
heartbeat.start()
def get_json_from_net(self, namespace="application"):
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
)
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
if code == 200:
if not body:
logger.error(f"get_json_from_net load configs failed, body is {body}")
return None
data = json.loads(body)
data = data["configurations"]
return_data = {CONFIGURATIONS: data}
return return_data
else:
return None
except Exception:
logger.exception("an error occurred in get_json_from_net")
return None
def get_value(self, key, default_val=None, namespace="application"):
try:
# read memory configuration
namespace_cache = self._cache.get(namespace)
val = get_value_from_dict(namespace_cache, key)
if val is not None:
return val
no_key = no_key_cache_key(namespace, key)
if no_key in self._no_key:
return default_val
# read the network configuration
namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key)
if val is not None:
self._update_cache_and_file(namespace_data, namespace)
return val
# read the file configuration
namespace_cache = self._get_local_cache(namespace)
val = get_value_from_dict(namespace_cache, key)
if val is not None:
self._update_cache_and_file(namespace_cache, namespace)
return val
# If all of them are not obtained, the default value is returned
# and the local cache is set to None
self._set_local_cache_none(namespace, key)
return default_val
except Exception:
logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
return default_val
# Set the key of a namespace to none, and do not set default val
# to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice
# and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace, key):
no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key
def _start_hot_update(self):
self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched.
self._long_poll_thread.daemon = True
self._long_poll_thread.start()
def stop(self):
self._stopping = True
logger.info("Stopping listener...")
# Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace, old_kv, new_kv):
if self._change_listener is None:
return
if old_kv is None:
old_kv = {}
if new_kv is None:
new_kv = {}
try:
for key in old_kv:
new_value = new_kv.get(key)
old_value = old_kv.get(key)
if new_value is None:
# If newValue is empty, it means key, and the value is deleted.
self._change_listener("delete", namespace, key, old_value)
continue
if new_value != old_value:
self._change_listener("update", namespace, key, new_value)
continue
for key in new_kv:
new_value = new_kv.get(key)
old_value = old_kv.get(key)
if old_value is None:
self._change_listener("add", namespace, key, new_value)
except BaseException as e:
logger.warning(str(e))
def _path_checker(self):
if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path)
# update the local cache and file cache
def _update_cache_and_file(self, namespace_data, namespace="application"):
# update the local cache
self._cache[namespace] = namespace_data
# update the file cache
new_string = json.dumps(namespace_data)
new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
if self._hash.get(namespace) == new_hash:
pass
else:
file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
file_path.write_text(new_string)
self._hash[namespace] = new_hash
# get the configuration from the local file
def _get_local_cache(self, namespace="application"):
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path):
with open(cache_file_path) as f:
result = json.loads(f.readline())
return result
return {}
def _long_poll(self):
notifications = []
for key in self._cache:
namespace_data = self._cache[key]
notification_id = -1
if NOTIFICATION_ID in namespace_data:
notification_id = self._cache[key][NOTIFICATION_ID]
notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
try:
# if the length is 0 it is returned directly
if len(notifications) == 0:
return
url = "{}/notifications/v2".format(self.config_url)
params = {
"appId": self.app_id,
"cluster": self.cluster,
"notifications": json.dumps(notifications, ensure_ascii=False),
}
param_str = url_encode_wrapper(params)
url = url + "?" + param_str
code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
http_code = code
if http_code == 304:
logger.debug("No change, loop...")
return
if http_code == 200:
if not body:
logger.error(f"_long_poll load configs failed,body is {body}")
return
data = json.loads(body)
for entry in data:
namespace = entry[NAMESPACE_NAME]
n_id = entry[NOTIFICATION_ID]
logger.info("%s has changes: notificationId=%d", namespace, n_id)
self._get_net_and_set_local(namespace, n_id, call_change=True)
return
else:
logger.warning("Sleep...")
except Exception as e:
logger.warning(str(e))
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
namespace_data = self.get_json_from_net(namespace)
if not namespace_data:
return
namespace_data[NOTIFICATION_ID] = n_id
old_namespace = self._cache.get(namespace)
self._update_cache_and_file(namespace_data, namespace)
if self._change_listener is not None and call_change and old_namespace:
old_kv = old_namespace.get(CONFIGURATIONS)
new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv)
def _listener(self):
logger.info("start long_poll")
while not self._stopping:
self._long_poll()
time.sleep(self._cycle_time)
logger.info("stopped, long_poll")
# add the need for endorsement to the header
def _sign_headers(self, url):
headers = {}
if self.secret == "":
return headers
uri = url[len(self.config_url) : len(url)]
time_unix_now = str(int(round(time.time() * 1000)))
headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
headers["Timestamp"] = time_unix_now
return headers
def _heart_beat(self):
while not self._stopping:
for namespace in self._notification_map:
self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10分钟
def _do_heart_beat(self, namespace):
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
if code == 200:
if not body:
logger.error(f"_do_heart_beat load configs failed,body is {body}")
return None
data = json.loads(body)
if self.last_release_key == data["releaseKey"]:
return None
self.last_release_key = data["releaseKey"]
data = data["configurations"]
self._update_cache_and_file(data, namespace)
else:
return None
except Exception:
logger.exception("an error occurred in _do_heart_beat")
return None
def get_all_dicts(self, namespace):
namespace_data = self._cache.get(namespace)
if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace)
if not net_namespace_data:
return namespace_data
namespace_data = net_namespace_data.get(CONFIGURATIONS)
if namespace_data:
self._update_cache_and_file(namespace_data, namespace)
return namespace_data

View File

@ -0,0 +1,41 @@
import logging
import os
import ssl
import urllib.request
from urllib import parse
from urllib.error import HTTPError
# Create an SSL context that allows for a lower level of security
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("HIGH:!DH:!aNULL")
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
# Create an opener object and pass in a custom SSL context
opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
urllib.request.install_opener(opener)
logger = logging.getLogger(__name__)
def http_request(url, timeout, headers={}):
try:
request = urllib.request.Request(url, headers=headers)
res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8")
return res.code, body
except HTTPError as e:
if e.code == 304:
logger.warning("http_request error,code is 304, maybe you should check secret")
return 304, None
logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg)
raise e
def url_encode(params):
return parse.urlencode(params)
def makedirs_wrapper(path):
os.makedirs(path, exist_ok=True)

View File

@ -0,0 +1,51 @@
import hashlib
import socket
from .python_3x import url_encode
# define constants
CONFIGURATIONS = "configurations"
NOTIFICATION_ID = "notificationId"
NAMESPACE_NAME = "namespaceName"
# add timestamps uris and keys
def signature(timestamp, uri, secret):
import base64
import hmac
string_to_sign = "" + timestamp + "\n" + uri
hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
return base64.b64encode(hmac_code).decode()
def url_encode_wrapper(params):
return url_encode(params)
def no_key_cache_key(namespace, key):
return "{}{}{}".format(namespace, len(namespace), key)
# Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache, key):
if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None:
return None
if key in kv_data:
return kv_data[key]
return None
def init_ip():
ip = ""
s = None
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 53))
ip = s.getsockname()[0]
finally:
if s:
s.close()
return ip

View File

@ -0,0 +1,15 @@
from collections.abc import Mapping
from typing import Any
from pydantic.fields import FieldInfo
class RemoteSettingsSource:
def __init__(self, configs: Mapping[str, Any]):
pass
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
return value

View File

@ -0,0 +1,5 @@
from enum import StrEnum
class RemoteSettingsSourceName(StrEnum):
APOLLO = "apollo"

View File

@ -14,11 +14,11 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@ -1,5 +1,6 @@
from datetime import UTC, datetime
from flask import request
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
@ -20,8 +21,17 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
current_tenant_id = current_user.current_tenant_id
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
if app_id:
installed_apps = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all()
)
else:
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_apps = [

View File

@ -368,6 +368,7 @@ class ToolWorkflowProviderCreateApi(Resource):
description=args["description"],
parameters=args["parameters"],
privacy_policy=args["privacy_policy"],
labels=args["labels"],
)

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -36,6 +36,29 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
@ -44,7 +67,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Generate App response.

View File

@ -19,6 +19,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@ -31,6 +32,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@ -127,7 +129,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
self.total_tokens: int = 0
def process(self):
"""
@ -318,7 +319,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
@ -361,8 +362,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not workflow_run:
raise Exception("Workflow run not initialized.")
# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@ -376,7 +375,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=self._conversation.id,
@ -387,6 +386,29 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
@ -404,6 +426,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
yield self._workflow_finish_to_stream_response(

View File

@ -2,7 +2,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -28,6 +28,39 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
*,
@ -36,7 +69,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
):
"""
Generate App response.

View File

@ -1,7 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
@ -34,9 +34,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[True] = True,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
@ -44,19 +44,29 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[dict, Generator[str, None, None]]:
):
"""
Generate App response.

View File

@ -1,7 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[True] = True,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
@ -44,14 +44,29 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
) -> Union[dict, Generator[str, None, None]]:
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Generate App response.

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -30,6 +30,35 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
@ -41,7 +70,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files

View File

@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
@ -34,7 +35,8 @@ class WorkflowAppQueueManager(AppQueueManager):
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()

View File

@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@ -26,6 +27,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@ -106,7 +108,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self.total_tokens: int = 0
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -258,36 +259,36 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response = self._workflow_node_start_to_stream_response(
node_start_response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
response = self._workflow_node_finish_to_stream_response(
node_success_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if node_failed_response:
yield node_failed_response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
@ -320,8 +321,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not workflow_run:
raise Exception("Workflow run not initialized.")
# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@ -335,7 +334,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
@ -345,6 +344,30 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
@ -354,7 +377,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@ -366,6 +388,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log

View File

@ -8,6 +8,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@ -18,6 +19,7 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@ -25,6 +27,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@ -32,6 +35,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@ -176,8 +180,12 @@ class WorkflowBasedAppRunner(AppRunner):
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
@ -253,6 +261,36 @@ class WorkflowBasedAppRunner(AppRunner):
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(

View File

@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
from pydantic import BaseModel, field_validator
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
@ -25,12 +25,14 @@ class QueueEvent(StrEnum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
NODE_EXCEPTION = "node_exception"
RETRIEVER_RESOURCES = "retriever_resources"
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
@ -113,18 +115,6 @@ class QueueIterationNextEvent(AppQueueEvent):
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError("output must be a valid type")
class QueueIterationCompletedEvent(AppQueueEvent):
"""
@ -249,6 +239,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
exceptions_count: int
class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None
class QueueNodeStartedEvent(AppQueueEvent):
@ -343,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
error: str
class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
"""
event: QueueEvent = QueueEvent.NODE_EXCEPTION
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity

View File

@ -213,6 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED

View File

@ -110,7 +110,7 @@ class RateLimitGenerator:
raise StopIteration
try:
return next(self.generator)
except StopIteration:
except Exception:
self.close()
raise

View File

@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@ -164,6 +165,55 @@ class WorkflowCycleManage:
return workflow_run
def _handle_workflow_run_partial_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
db.session.close()
return workflow_run
def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
@ -174,6 +224,7 @@ class WorkflowCycleManage:
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
@ -193,7 +244,7 @@ class WorkflowCycleManage:
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
running_workflow_node_executions = (
@ -318,7 +369,7 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@ -337,7 +388,11 @@ class WorkflowCycleManage:
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.status: (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
),
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
@ -351,8 +406,11 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
)
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
@ -433,6 +491,7 @@ class WorkflowCycleManage:
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
exceptions_count=workflow_run.exceptions_count,
),
)
@ -483,7 +542,10 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:

View File

@ -141,7 +141,7 @@ def _to_url(f: File, /):
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
if f.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=f.related_id)
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
# add sign url
if f.related_id is None or f.extension is None:

View File

@ -24,6 +24,12 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
class MaxRetriesExceededError(Exception):
"""Raised when the maximum number of retries is exceeded."""
pass
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@ -66,7 +72,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

View File

@ -0,0 +1,21 @@
import boto3
from botocore.config import Config
def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"]
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock
client = boto3.client(
service_name=service_name,
config=client_config,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
else:
# use iam without aksk to call
client = boto3.client(service_name=service_name, config=client_config)
return client

View File

@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
@ -173,13 +174,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
bedrock_client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"],
)
bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

View File

@ -1,8 +1,5 @@
from typing import Optional
import boto3
from botocore.config import Config
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@ -14,6 +11,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
class BedrockRerankModel(RerankModel):
@ -48,13 +46,7 @@ class BedrockRerankModel(RerankModel):
return RerankResult(model=model, docs=docs)
# initialize client
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-agent-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id", ""),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
text_sources = []
for text in docs:

View File

@ -3,8 +3,6 @@ import logging
import time
from typing import Optional
import boto3
from botocore.config import Config
from botocore.exceptions import (
ClientError,
EndpointConnectionError,
@ -25,6 +23,7 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
logger = logging.getLogger(__name__)
@ -48,14 +47,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
:param input_type: input type
:return: embeddings result
"""
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
embeddings = []
token_usage = 0

View File

@ -0,0 +1,39 @@
model: gemini-2.0-flash-exp
label:
en_US: Gemini 2.0 Flash Exp
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,38 @@
model: gemini-exp-1206
label:
en_US: Gemini exp 1206
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 2097152
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -1,4 +1,5 @@
- llama-3.1-405b-reasoning
- llama-3.3-70b-versatile
- llama-3.1-70b-versatile
- llama-3.1-8b-instant
- llama3-70b-8192

View File

@ -0,0 +1,25 @@
model: gemma-7b-it
label:
zh_Hans: Gemma 7B Instruction Tuned
en_US: Gemma 7B Instruction Tuned
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.1'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,25 @@
model: gemma2-9b-it
label:
zh_Hans: Gemma 2 9B Instruction Tuned
en_US: Gemma 2 9B Instruction Tuned
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.1'
unit: '0.000001'
currency: USD

View File

@ -1,7 +1,8 @@
model: llama-3.1-70b-versatile
deprecated: true
label:
zh_Hans: Llama-3.1-70b-versatile
en_US: Llama-3.1-70b-versatile
zh_Hans: Llama-3.1-70b-versatile (DEPRECATED)
en_US: Llama-3.1-70b-versatile (DEPRECATED)
model_type: llm
features:
- agent-thought

View File

@ -1,4 +1,5 @@
model: llama-3.2-11b-text-preview
deprecated: true
label:
zh_Hans: Llama 3.2 11B Text (Preview)
en_US: Llama 3.2 11B Text (Preview)

View File

@ -1,4 +1,5 @@
model: llama-3.2-90b-text-preview
depraceted: true
label:
zh_Hans: Llama 3.2 90B Text (Preview)
en_US: Llama 3.2 90B Text (Preview)

View File

@ -0,0 +1,25 @@
model: llama-3.3-70b-specdec
label:
zh_Hans: Llama 3.3 70B Specdec
en_US: Llama 3.3 70B Specdec
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32768
pricing:
input: "0.05"
output: "0.1"
unit: "0.000001"
currency: USD

View File

@ -0,0 +1,25 @@
model: llama-3.3-70b-versatile
label:
zh_Hans: Llama 3.3 70B Versatile
en_US: Llama 3.3 70B Versatile
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32768
pricing:
input: "0.05"
output: "0.1"
unit: "0.000001"
currency: USD

View File

@ -0,0 +1,25 @@
model: llama3-groq-70b-8192-tool-use-preview
label:
zh_Hans: Llama3-groq-70b-8192-tool-use (PREVIEW)
en_US: Llama3-groq-70b-8192-tool-use (PREVIEW)
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.08'
unit: '0.000001'
currency: USD

View File

@ -35,6 +35,7 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxLargeLanguageModel(LargeLanguageModel):
model_apis = {
"abab7-chat-preview": MinimaxChatCompletionPro,
"abab6.5t-chat": MinimaxChatCompletionPro,
"abab6.5s-chat": MinimaxChatCompletionPro,
"abab6.5-chat": MinimaxChatCompletionPro,
"abab6-chat": MinimaxChatCompletionPro,

View File

@ -1,3 +1,5 @@
- pixtral-large-latest
- pixtral-large-2411
- pixtral-12b-2409
- codestral-latest
- mistral-embed

View File

@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 128000
@ -21,7 +22,7 @@ parameter_rules:
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
default: 8192
min: 1
max: 8192
- name: safe_prompt

View File

@ -0,0 +1,52 @@
model: pixtral-large-2411
label:
zh_Hans: pixtral-large-2411
en_US: pixtral-large-2411
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,52 @@
model: pixtral-large-latest
label:
zh_Hans: pixtral-large-latest
en_US: pixtral-large-latest
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD

View File

@ -181,9 +181,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
# prepare the payload for a simple ping to the model
data = {"model": model, "stream": stream}
if "format" in model_parameters:
data["format"] = model_parameters["format"]
del model_parameters["format"]
if format_schema := model_parameters.pop("format", None):
try:
data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
except json.JSONDecodeError as e:
raise InvokeBadRequestError(f"Invalid format schema: {str(e)}")
if "keep_alive" in model_parameters:
data["keep_alive"] = model_parameters["keep_alive"]
@ -733,12 +735,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
ParameterRule(
name="format",
label=I18nObject(en_US="Format", zh_Hans="返回格式"),
type=ParameterType.STRING,
type=ParameterType.TEXT,
default="json",
help=I18nObject(
en_US="the format to return a response in. Currently the only accepted value is json.",
zh_Hans="返回响应的格式。目前唯一接受的值是json。",
en_US="the format to return a response in. Format can be `json` or a JSON schema.",
zh_Hans="返回响应的格式。目前接受的值是字符串`json`或JSON schema.",
),
options=["json"],
),
],
pricing=PriceConfig(

View File

@ -478,6 +478,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
usage=usage,
)
break
# handle the error here. for issue #11629
if chunk_json.get("error") and chunk_json.get("choices") is None:
raise ValueError(chunk_json.get("error"))
if chunk_json:
if u := chunk_json.get("usage"):
usage = u

View File

@ -1,4 +1,5 @@
- Tencent/Hunyuan-A52B-Instruct
- Qwen/QwQ-32B-Preview
- Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2.5-32B-Instruct
- Qwen/Qwen2.5-14B-Instruct
@ -19,6 +20,7 @@
- 01-ai/Yi-1.5-6B-Chat
- internlm/internlm2_5-20b-chat
- internlm/internlm2_5-7b-chat
- meta-llama/Llama-3.3-70B-Instruct
- meta-llama/Meta-Llama-3.1-405B-Instruct
- meta-llama/Meta-Llama-3.1-70B-Instruct
- meta-llama/Meta-Llama-3.1-8B-Instruct

View File

@ -0,0 +1,53 @@
model: meta-llama/Llama-3.3-70B-Instruct
label:
en_US: meta-llama/Llama-3.3-70B-Instruct
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '4.13'
output: '4.13'
unit: '0.000001'
currency: RMB

View File

@ -0,0 +1,53 @@
model: Qwen/QwQ-32B-Preview
label:
en_US: Qwen/QwQ-32B-Preview
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '1.26'
output: '1.26'
unit: '0.000001'
currency: RMB

View File

@ -59,8 +59,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@ -59,8 +59,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@ -58,8 +58,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@ -59,8 +59,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@ -59,8 +59,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@ -0,0 +1,39 @@
model: gemini-2.0-flash-exp
label:
en_US: Gemini 2.0 Flash Exp
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -104,13 +104,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
"""
# use Anthropic official SDK references
# - https://github.com/anthropics/anthropic-sdk-python
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
token = ""
# get access token from service account credential
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
@ -478,10 +479,11 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
if stop:
config_kwargs["stop_sequences"] = stop
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:

View File

@ -48,10 +48,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param input_type: input type
:return: embeddings result
"""
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:
@ -100,10 +101,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:return:
"""
try:
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:

View File

@ -0,0 +1,66 @@
model: grok-2-1212
label:
en_US: grok-2-1212
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,64 @@
model: grok-2-vision-1212
label:
en_US: grok-2-vision-1212
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -1,6 +1,6 @@
model: grok-beta
label:
en_US: Grok Beta
en_US: grok-beta
model_type: llm
features:
- agent-thought

View File

@ -1,6 +1,6 @@
model: grok-vision-beta
label:
en_US: Grok Vision Beta
en_US: grok-vision-beta
model_type: llm
features:
- agent-thought

View File

@ -0,0 +1,52 @@
model: glm-4v-flash
label:
en_US: glm-4v-flash
model_type: llm
model_properties:
mode: chat
context_size: 2048
features:
- vision
parameter_rules:
- name: temperature
use_template: temperature
default: 0.95
min: 0.0
max: 1.0
help:
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: top_p
use_template: top_p
default: 0.6
help:
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: do_sample
label:
zh_Hans: 采样策略
en_US: Sampling strategy
type: boolean
help:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 1024
- name: web_search
type: boolean
label:
zh_Hans: 联网搜索
en_US: Web Search
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: RMB

View File

@ -144,7 +144,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
if isinstance(copy_prompt_message.content, list):
# check if model is 'glm-4v'
if model not in {"glm-4v", "glm-4v-plus"}:
if not model.startswith("glm-4v"):
# not support list message
continue
# get image and
@ -188,7 +188,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
else:
model_parameters["tools"] = [web_search_params]
if model in {"glm-4v", "glm-4v-plus"}:
if model.startswith("glm-4v"):
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
else:
params = {"model": model, "messages": [], **model_parameters}
@ -412,6 +412,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
human_prompt = "\n\nHuman:"
ai_prompt = "\n\nAssistant:"
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"

View File

@ -4,7 +4,7 @@ import os
from datetime import datetime, timedelta
from typing import Optional
from langfuse import Langfuse
from langfuse import Langfuse # type: ignore
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
@ -65,8 +65,11 @@ class LangFuseDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
trace_id = trace_info.workflow_run_id
user_id = trace_info.metadata.get("user_id")
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
trace_id = trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE.value
@ -76,22 +79,20 @@ class LangFuseDataTrace(BaseTraceInstance):
name=name,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
metadata=trace_info.metadata,
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["message", "workflow"],
created_at=trace_info.start_time,
updated_at=trace_info.end_time,
)
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
trace_id=trace_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
metadata=metadata,
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
status_message=trace_info.error or "",
)
@ -103,7 +104,7 @@ class LangFuseDataTrace(BaseTraceInstance):
name=TraceTaskName.WORKFLOW_TRACE.value,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
metadata=trace_info.metadata,
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["workflow"],
)
@ -192,7 +193,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
parent_observation_id=trace_info.workflow_run_id,
)
else:
span_data = LangfuseSpan(
@ -239,11 +240,13 @@ class LangFuseDataTrace(BaseTraceInstance):
file_list = trace_info.file_list
metadata = trace_info.metadata
message_data = trace_info.message_data
if message_data is None:
return
message_id = message_data.id
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: EndUser = (
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
@ -300,6 +303,8 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_generation(langfuse_generation_data)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
span_data = LangfuseSpan(
name=TraceTaskName.MODERATION_TRACE.value,
input=trace_info.inputs,
@ -319,9 +324,11 @@ class LangFuseDataTrace(BaseTraceInstance):
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
generation_usage = GenerationUsage(
total=len(str(trace_info.suggested_question)),
input=len(trace_info.inputs),
input=len(trace_info.inputs) if trace_info.inputs else 0,
output=len(trace_info.suggested_question),
unit=UnitEnum.CHARACTERS,
)
@ -342,6 +349,8 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_generation(langfuse_generation_data=generation_data)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
dataset_retrieval_span_data = LangfuseSpan(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
input=trace_info.inputs,

View File

@ -62,15 +62,17 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
trace_id = trace_info.message_id or trace_info.workflow_run_id
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
workflow_dotted_order = generate_dotted_order(
trace_info.workflow_app_log_id or trace_info.workflow_run_id,
trace_info.workflow_run_id,
trace_info.workflow_data.created_at,
message_dotted_order,
)
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
message_run = LangSmithRunModel(
@ -82,7 +84,7 @@ class LangSmithDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
extra={
"metadata": trace_info.metadata,
"metadata": metadata,
},
tags=["message", "workflow"],
error=trace_info.error,
@ -94,7 +96,7 @@ class LangSmithDataTrace(BaseTraceInstance):
langsmith_run = LangSmithRunModel(
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
inputs=trace_info.workflow_run_inputs,
run_type=LangSmithRunType.tool,
@ -102,7 +104,7 @@ class LangSmithDataTrace(BaseTraceInstance):
end_time=trace_info.workflow_data.finished_at,
outputs=trace_info.workflow_run_outputs,
extra={
"metadata": trace_info.metadata,
"metadata": metadata,
},
error=trace_info.error,
tags=["workflow"],
@ -204,7 +206,7 @@ class LangSmithDataTrace(BaseTraceInstance):
extra={
"metadata": metadata,
},
parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
parent_run_id=trace_info.workflow_run_id,
tags=["node_execution"],
id=node_execution_id,
trace_id=trace_id,

View File

@ -1,13 +1,10 @@
import copy
import json
import logging
from collections.abc import Iterable
from typing import Any, Optional
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_fixed
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -24,11 +21,15 @@ logging.basicConfig(level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)
ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index"
class LindormVectorStoreConfig(BaseModel):
hosts: str
username: Optional[str] = None
password: Optional[str] = None
using_ugc: Optional[bool] = False
@model_validator(mode="before")
@classmethod
@ -42,9 +43,7 @@ class LindormVectorStoreConfig(BaseModel):
return values
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts,
}
params = {"hosts": self.hosts}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
@ -52,9 +51,21 @@ class LindormVectorStoreConfig(BaseModel):
class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
super().__init__(collection_name.lower())
self._routing = None
self._routing_field = None
if config.using_ugc:
routing_value: str = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
self._routing_field = ROUTING_FIELD
ugc_index_name = collection_name
super().__init__(ugc_index_name.lower())
else:
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self._using_ugc = config.using_ugc
self.kwargs = kwargs
def get_type(self) -> str:
@ -67,94 +78,37 @@ class LindormVectorStore(BaseVector):
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
def __filter_existed_ids(
self,
texts: list[str],
metadatas: list[dict],
ids: list[str],
bulk_size: int = 1024,
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(index=self._collection_name, body={
"ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}")
return set()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(
body={
"docs": [
{"_index": self._collection_name,
"_id": id, "routing": routing}
for id, routing in zip(batch_ids, route_ids)
]
},
_source=False,
)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch ids: {batch_ids}")
return set()
if ids is None:
return texts, metadatas, ids
if len(texts) != len(ids):
raise RuntimeError(f"texts {len(texts)} != {ids}")
filtered_texts = []
filtered_metadatas = []
filtered_ids = []
def batch(iterable, n):
length = len(iterable)
for idx in range(0, length, n):
yield iterable[idx: min(idx + n, length)]
for ids_batch, texts_batch, metadatas_batch in zip(
batch(ids, bulk_size),
batch(texts, bulk_size),
batch(metadatas, bulk_size) if metadatas is not None else batch(
[None] * len(ids), bulk_size),
):
existing_ids_set = __fetch_existing_ids(ids_batch)
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
if doc_id not in existing_ids_set:
filtered_texts.append(text)
filtered_ids.append(doc_id)
if metadatas is not None:
filtered_metadatas.append(metadata)
return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuids[i],
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
# Make sure you pass an array here
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
},
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],
}
}
actions.append(action)
bulk(self._client, actions)
self.refresh()
action_values = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
if response["errors"]:
for item in response["items"]:
print(f"{item['index']['status']}: {item['index']['error']['type']}")
else:
self.refresh()
def get_ids_by_metadata_field(self, key: str, value: str):
query = {
"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
@ -162,47 +116,54 @@ class LindormVectorStore(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(
index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
params = {}
if self._using_ugc:
params["routing"] = self._routing
for id in ids:
if self._client.exists(index=self._collection_name, id=id):
self._client.delete(index=self._collection_name, id=id)
if self._client.exists(index=self._collection_name, id=id, params=params):
params = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.delete(index=self._collection_name, id=id, params=params)
self.refresh()
else:
logger.warning(
f"DELETE BY ID: ID {id} does not exist in the index.")
def delete(self) -> None:
try:
if self._using_ugc:
routing_filter_query = {
"query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
}
self._client.delete_by_query(self._collection_name, body=routing_filter_query)
self.refresh()
else:
if self._client.indices.exists(index=self._collection_name):
self._client.indices.delete(
index=self._collection_name, params={"timeout": 60})
logger.info("Delete index success")
else:
logger.warning(
f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e:
logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
raise e
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
def text_exists(self, id: str) -> bool:
try:
self._client.get(index=self._collection_name, id=id)
params = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.get(index=self._collection_name, id=id, params=params)
return True
except:
return False
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Make sure query_vector is a list
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")
# Check whether query_vector is a floating-point number list
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")
@ -210,8 +171,10 @@ class LindormVectorStore(BaseVector):
query = default_vector_search_query(
query_vector=query_vector, k=top_k, **kwargs)
try:
response = self._client.search(
index=self._collection_name, body=query)
params = {}
if self._using_ugc:
params["routing"] = self._routing
response = self._client.search(index=self._collection_name, body=query, params=params)
except Exception as e:
logger.exception(f"Error executing vector search, query: {query}")
raise
@ -244,7 +207,7 @@ class LindormVectorStore(BaseVector):
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 10)
filters = kwargs.get("filter")
routing = kwargs.get("routing")
routing = self._routing
full_text_query = default_text_search_query(
query_text=query,
k=top_k,
@ -255,6 +218,7 @@ class LindormVectorStore(BaseVector):
minimum_should_match=minimum_should_match,
filters=filters,
routing=routing,
routing_field=self._routing_field,
)
response = self._client.search(
index=self._collection_name, body=full_text_query)
@ -279,17 +243,18 @@ class LindormVectorStore(BaseVector):
f"Collection {self._collection_name} already exists.")
return
if self._client.indices.exists(index=self._collection_name):
logger.info("{self._collection_name.lower()} already exists.")
logger.info(f"{self._collection_name.lower()} already exists.")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
return
if len(self.kwargs) == 0 and len(kwargs) != 0:
self.kwargs = copy.deepcopy(kwargs)
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
shards = kwargs.pop("shards", 2)
shards = kwargs.pop("shards", 4)
engine = kwargs.pop("engine", "lvector")
method_name = kwargs.pop("method_name", "hnsw")
method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
data_type = kwargs.pop("data_type", "float")
space_type = kwargs.pop("space_type", "cosinesimil")
hnsw_m = kwargs.pop("hnsw_m", 24)
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
@ -305,10 +270,10 @@ class LindormVectorStore(BaseVector):
mapping = default_text_mapping(
dimension,
method_name,
space_type=space_type,
shards=shards,
engine=engine,
data_type=data_type,
space_type=space_type,
vector_field=vector_field,
hnsw_m=hnsw_m,
hnsw_ef_construction=hnsw_ef_construction,
@ -318,6 +283,7 @@ class LindormVectorStore(BaseVector):
centroids_hnsw_m=centroids_hnsw_m,
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
using_ugc=self._using_ugc,
**kwargs,
)
self._client.indices.create(
@ -327,15 +293,20 @@ class LindormVectorStore(BaseVector):
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
routing_field = kwargs.get("routing_field")
excludes_from_source = kwargs.get("excludes_from_source")
analyzer = kwargs.get("analyzer", "ik_max_word")
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
engine = kwargs["engine"]
shard = kwargs["shards"]
space_type = kwargs["space_type"]
space_type = kwargs.get("space_type")
if space_type is None:
if method_name == "hnsw":
space_type = "l2"
else:
space_type = "cosine"
data_type = kwargs["data_type"]
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
using_ugc = kwargs.get("using_ugc", False)
if method_name == "ivfpq":
ivfpq_m = kwargs["ivfpq_m"]
@ -385,13 +356,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
# e.g. {"excludes": ["vector_field"]}
mapping["mappings"]["_source"] = {"excludes": excludes_from_source}
if method_name == "ivfpq" and routing_field is not None:
if using_ugc and method_name == "ivfpq":
mapping["settings"]["index"]["knn_routing"] = True
mapping["settings"]["index"]["knn.offline.construction"] = True
if method_name == "flat" and routing_field is not None:
elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat":
mapping["settings"]["index"]["knn_routing"] = True
return mapping
@ -405,14 +374,12 @@ def default_text_search_query(
minimum_should_match: int = 0,
filters: Optional[list[dict]] = None,
routing: Optional[str] = None,
routing_field: Optional[str] = None,
**kwargs,
) -> dict:
if routing is not None:
routing_field = kwargs.get("routing_field", "routing_field")
query_clause = {
"bool": {
"must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
}
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
}
else:
query_clause = {"match": {text_field: query_text}}
@ -472,7 +439,7 @@ def default_vector_search_query(
) -> dict:
if filters is not None:
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filter, list):
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext = {"lvector": {}}
if min_score != "0.0":
@ -508,17 +475,40 @@ def default_vector_search_query(
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
)
return LindormVectorStore(collection_name, lindorm_config)
using_ugc = dify_config.USING_UGC_INDEX
routing_value = None
if dataset.index_struct:
if using_ugc:
dimension = dataset.index_struct_dict["dimension"]
index_type = dataset.index_struct_dict["index_type"]
distance_type = dataset.index_struct_dict["distance_type"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
else:
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
else:
embedding_vector = embeddings.embed_query("hello word")
dimension = len(embedding_vector)
index_type = dify_config.DEFAULT_INDEX_TYPE
distance_type = dify_config.DEFAULT_DISTANCE_TYPE
class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
index_struct_dict = {
"type": VectorType.LINDORM,
"vector_store": {"class_prefix": class_prefix},
"index_type": index_type,
"dimension": dimension,
"distance_type": distance_type,
}
dataset.index_struct = json.dumps(index_struct_dict)
if using_ugc:
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
routing_value = class_prefix
else:
index_name = class_prefix
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)

View File

@ -162,7 +162,7 @@ class TidbService:
clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "FULL"}
params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = requests.get(
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)

View File

@ -37,8 +37,6 @@ class TiDBVectorConfig(BaseModel):
raise ValueError("config TIDB_VECTOR_PORT is required")
if not values["user"]:
raise ValueError("config TIDB_VECTOR_USER is required")
if not values["password"]:
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
if not values["database"]:
raise ValueError("config TIDB_VECTOR_DATABASE is required")
if not values["program_name"]:

View File

@ -103,7 +103,7 @@ class ExtractProcessor:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
elif file_extension in {".md", ".markdown"}:
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
if is_automatic
@ -141,7 +141,7 @@ class ExtractProcessor:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
elif file_extension in {".md", ".markdown"}:
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)

View File

@ -270,9 +270,6 @@ class ApiTool(Tool):
elif property["type"] == "object" or property["type"] == "array":
if isinstance(value, str):
try:
# an array str like '[1,2]' also can convert to list [1,2] through json.loads
# json not support single quote, but we can support it
value = value.replace("'", '"')
return json.loads(value)
except ValueError:
return value

View File

@ -4,6 +4,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@ -39,6 +40,8 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunPartialSucceededEvent):
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):

View File

@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum):
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
class NodeRunResult(BaseModel):
@ -43,3 +44,4 @@ class NodeRunResult(BaseModel):
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
@ -32,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: Optional[int] = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
exceptions_count: int = Field(..., description="exception count")
outputs: Optional[dict[str, Any]] = None
###########################################
@ -82,6 +89,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunExceptionEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
@ -140,8 +151,8 @@ class BaseIterationEvent(GraphEngineEvent):
class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
@ -153,18 +164,18 @@ class IterationRunNextEvent(BaseIterationEvent):
class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -64,13 +64,21 @@ class Graph(BaseModel):
edge_configs = graph_config.get("edges")
if edge_configs is None:
edge_configs = []
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
edge_configs = cast(list, edge_configs)
node_configs = cast(list, node_configs)
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
fail_branch_source_node_id = [
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
]
for edge_config in edge_configs:
source_node_id = edge_config.get("source")
if not source_node_id:
@ -90,8 +98,16 @@ class Graph(BaseModel):
# parse run condition
run_condition = None
if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle"))
if edge_config.get("sourceHandle"):
if (
edge_config.get("source") in fail_branch_source_node_id
and edge_config.get("sourceHandle") != "fail-branch"
):
run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
elif edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(
type="branch_identify", branch_identify=edge_config.get("sourceHandle")
)
graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
@ -100,13 +116,6 @@ class Graph(BaseModel):
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = cast(list, node_configs)
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}

View File

@ -15,6 +15,7 @@ class RouteNodeState(BaseModel):
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
EXCEPTION = "exception"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""
@ -51,7 +52,11 @@ class RouteNodeState(BaseModel):
:param run_result: run result
"""
if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
if self.status in {
RouteNodeState.Status.SUCCESS,
RouteNodeState.Status.FAILED,
RouteNodeState.Status.EXCEPTION,
}:
raise Exception(f"Route state {self.id} already finished")
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@ -59,6 +64,9 @@ class RouteNodeState(BaseModel):
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
self.status = RouteNodeState.Status.FAILED
self.failed_reason = run_result.error
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
self.status = RouteNodeState.Status.EXCEPTION
self.failed_reason = run_result.error
else:
raise Exception(f"Invalid route status {run_result.status}")

View File

@ -5,21 +5,23 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional
from typing import Any, Optional, cast
from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@ -36,7 +38,9 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
@ -128,6 +132,7 @@ class GraphEngine:
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
handle_exceptions = []
try:
if self.init_params.workflow_type == WorkflowType.CHAT:
@ -140,13 +145,17 @@ class GraphEngine:
)
# run graph
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
)
for item in generator:
try:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
yield GraphRunFailedEvent(
error=item.route_node_state.failed_reason or "Unknown error.",
exceptions_count=len(handle_exceptions),
)
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
@ -172,19 +181,24 @@ class GraphEngine:
].strip()
except Exception as e:
logger.exception("Graph run failed")
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
return
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
# count exceptions to determine partial success
if len(handle_exceptions) > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs
)
else:
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
self._release_thread()
except GraphRunFailedError as e:
yield GraphRunFailedEvent(error=e.error)
yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions))
self._release_thread()
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
self._release_thread()
raise e
@ -198,6 +212,7 @@ class GraphEngine:
in_parallel_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
parallel_start_node_id = None
if in_parallel_id:
@ -242,7 +257,7 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
try:
# run node
generator = self._run_node(
@ -252,6 +267,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
@ -301,7 +317,12 @@ class GraphEngine:
if len(edge_mappings) == 1:
edge = edge_mappings[0]
if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None
):
break
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@ -334,7 +355,7 @@ class GraphEngine:
if len(sub_edge_mappings) == 0:
continue
edge = sub_edge_mappings[0]
edge = cast(GraphEdge, sub_edge_mappings[0])
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@ -355,6 +376,7 @@ class GraphEngine:
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
@ -369,11 +391,18 @@ class GraphEngine:
break
next_node_id = final_node_id
elif (
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node_instance.should_continue_on_error
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
):
break
else:
parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
@ -395,6 +424,7 @@ class GraphEngine:
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
@ -438,6 +468,7 @@ class GraphEngine:
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)
@ -481,6 +512,7 @@ class GraphEngine:
parallel_start_node_id: str,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> None:
"""
Run parallel nodes
@ -502,6 +534,7 @@ class GraphEngine:
in_parallel_id=parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
@ -548,6 +581,7 @@ class GraphEngine:
parallel_start_node_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
@ -587,19 +621,55 @@ class GraphEngine:
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
@ -735,6 +805,56 @@ class GraphEngine:
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
return new_instance
def _handle_continue_on_error(
self,
node_instance: BaseNode[BaseNodeData],
error_result: NodeRunResult,
variable_pool: VariablePool,
handle_exceptions: list[str] = [],
) -> NodeRunResult:
"""
handle continue on error when self._should_continue_on_error is True
:param error_result (NodeRunResult): error run result
:param variable_pool (VariablePool): variable pool
:return: excption run result
"""
# add error message and error type to variable pool
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error)
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node_instance.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult(
**node_error_args,
outputs={
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
return error_result
class GraphRunFailedError(Exception):
def __init__(self, error: str):

View File

@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -148,13 +148,18 @@ class AnswerStreamGeneratorRouter:
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER,
}:
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (
source_node_type
in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER,
}
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(

View File

@ -6,6 +6,7 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunExceptionEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@ -50,7 +51,7 @@ class AnswerStreamProcessor(StreamProcessor):
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished

View File

@ -1,14 +1,124 @@
import json
from abc import ABC
from typing import Optional
from enum import StrEnum
from typing import Any, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from core.workflow.nodes.base.exc import DefaultValueTypeError
from core.workflow.nodes.enums import ErrorStrategy
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
NumberType = Union[int, float]
class DefaultValue(BaseModel):
value: Any
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str) -> Any:
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
if self.type is None:
raise DefaultValueTypeError("type field is required")
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator = type_validators.get(self.type)
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
@property
def default_value_dict(self):
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None

View File

@ -0,0 +1,10 @@
class BaseNodeError(Exception):
"""Base class for node errors."""
pass
class DefaultValueTypeError(BaseNodeError):
"""Raised when the default value type is invalid."""
pass

View File

@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@ -72,10 +72,7 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run()
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@ -137,3 +134,12 @@ class BaseNode(Generic[GenericNodeData]):
:return:
"""
return self._node_type
@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error
Returns:
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE

Some files were not shown because too many files have changed in this diff Show More