mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in: commit c3f3b79b79
@@ -50,6 +50,9 @@ jobs:
       - name: Run ModelRuntime
         run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
 
+      - name: Run dify config tests
+        run: poetry run -C api python dev/pytest/pytest_config_tests.py
+
       - name: Run Tool
         run: poetry run -C api bash dev/pytest/pytest_tools.sh

@@ -9,5 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
 yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
+yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
 
-echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
+echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"

@@ -60,17 +60,8 @@ DB_DATABASE=dify
 STORAGE_TYPE=opendal
 
 # Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
-STORAGE_OPENDAL_SCHEME=fs
-# OpenDAL FS
+OPENDAL_SCHEME=fs
 OPENDAL_FS_ROOT=storage
-# OpenDAL S3
-OPENDAL_S3_ROOT=/
-OPENDAL_S3_BUCKET=your-bucket-name
-OPENDAL_S3_ENDPOINT=https://s3.amazonaws.com
-OPENDAL_S3_ACCESS_KEY_ID=your-access-key
-OPENDAL_S3_SECRET_ACCESS_KEY=your-secret-key
-OPENDAL_S3_REGION=your-region
-OPENDAL_S3_SERVER_SIDE_ENCRYPTION=
 
 # S3 Storage configuration
 S3_USE_AWS_MANAGED_IAM=false

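Dropping the dedicated `OPENDAL_S3_*` keys does not remove S3 support: with a single `OPENDAL_SCHEME` selector, any `OPENDAL_<SCHEME>_*` variable can be forwarded generically to the storage operator. A rough sketch of that convention follows; the helper name is hypothetical, not dify's actual function:

```python
import os

def collect_opendal_kwargs(scheme: str, prefix: str = "OPENDAL_") -> dict[str, str]:
    """Collect scheme-specific OpenDAL options from the environment.

    With OPENDAL_SCHEME=fs and OPENDAL_FS_ROOT=storage set, this returns
    {"root": "storage"}; with scheme "s3", OPENDAL_S3_BUCKET=my-bucket
    becomes {"bucket": "my-bucket"}.
    """
    config_prefix = f"{prefix}{scheme.upper()}_"
    return {
        key[len(config_prefix):].lower(): value
        for key, value in os.environ.items()
        if key.startswith(config_prefix)
    }
```

Under this scheme the typed config only needs the scheme name; everything else rides along by prefix.
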
@@ -313,8 +304,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
 # Model configuration
-MULTIMODAL_SEND_IMAGE_FORMAT=base64
-MULTIMODAL_SEND_VIDEO_FORMAT=base64
+MULTIMODAL_SEND_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
 

@@ -409,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
 WORKFLOW_MAX_EXECUTION_STEPS=500
 WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
+WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800
 
 # App configuration

@@ -435,3 +426,5 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
 
 # Maximum number of submitted thread count in a ThreadPool for parallel node execution
 MAX_SUBMIT_COUNT=100
+# Lockout duration in seconds
+LOGIN_LOCKOUT_DURATION=86400

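The new lockout setting pairs a failed-attempt limit with a cool-down window. An in-process stand-in of the idea is sketched below; the real implementation presumably tracks attempts in Redis with a TTL, and the `MAX_FAILED_ATTEMPTS` threshold here is assumed for illustration:

```python
import time

LOGIN_LOCKOUT_DURATION = 86400  # seconds, mirrors the default above
MAX_FAILED_ATTEMPTS = 5         # assumed threshold, not from the diff

_failures: dict[str, int] = {}
_locked_until: dict[str, float] = {}

def record_failed_login(email: str) -> None:
    # Count consecutive failures; once over the threshold, lock the account
    # out until the cool-down window expires.
    _failures[email] = _failures.get(email, 0) + 1
    if _failures[email] >= MAX_FAILED_ATTEMPTS:
        _locked_until[email] = time.time() + LOGIN_LOCKOUT_DURATION

def is_locked_out(email: str) -> bool:
    return time.time() < _locked_until.get(email, 0.0)
```
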
@@ -70,7 +70,6 @@ ignore = [
     "SIM113", # enumerate-for-loop
     "SIM117", # multiple-with-statements
-    "SIM210", # if-expr-with-true-false
     "SIM300", # yoda-conditions,
 ]
 
 [lint.per-file-ignores]

api/app.py
@@ -1,13 +1,30 @@
-from app_factory import create_app
-from libs import threadings_utils, version_utils
+from libs import version_utils
 
 # preparation before creating app
 version_utils.check_supported_python_version()
-threadings_utils.apply_gevent_threading_patch()
+
+
+def is_db_command():
+    import sys
+
+    if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
+        return True
+    return False
+
 
 # create app
-app = create_app()
-celery = app.extensions["celery"]
+if is_db_command():
+    from app_factory import create_migrations_app
+
+    app = create_migrations_app()
+else:
+    from app_factory import create_app
+    from libs import threadings_utils
+
+    threadings_utils.apply_gevent_threading_patch()
+
+    app = create_app()
+    celery = app.extensions["celery"]
+
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=5001)

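The gating above lets `flask db` migration commands get a slim app without gevent monkey-patching or Celery wiring, while normal serving keeps the full setup. A minimal, runnable demonstration of how that argv check behaves under the two entry modes (paths are illustrative):

```python
import sys

def is_db_command() -> bool:
    # `flask db upgrade` arrives as sys.argv like ["/path/to/flask", "db", "upgrade"]
    return len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db"

# Simulate the migration entry point:
sys.argv = ["/usr/local/bin/flask", "db", "upgrade"]
assert is_db_command()

# Simulate a normal server entry point:
sys.argv = ["/usr/local/bin/gunicorn", "app:app"]
assert not is_db_command()
```
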
@@ -1,5 +1,4 @@
 import logging
-import os
 import time
 
 from configs import dify_config

@@ -17,15 +16,6 @@ def create_flask_app_with_configs() -> DifyApp:
     dify_app = DifyApp(__name__)
     dify_app.config.from_mapping(dify_config.model_dump())
 
-    # populate configs into system environment variables
-    for key, value in dify_app.config.items():
-        if isinstance(value, str):
-            os.environ[key] = value
-        elif isinstance(value, int | float | bool):
-            os.environ[key] = str(value)
-        elif value is None:
-            os.environ[key] = ""
-
     return dify_app
 
 

@@ -98,3 +88,14 @@ def initialize_extensions(app: DifyApp):
         end_time = time.perf_counter()
         if dify_config.DEBUG:
             logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
+
+
+def create_migrations_app():
+    app = create_flask_app_with_configs()
+    from extensions import ext_database, ext_migrate
+
+    # Initialize only required extensions
+    ext_database.init_app(app)
+    ext_migrate.init_app(app)
+
+    return app

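One plausible motivation for dropping the environment-population loop: `os.environ` can only hold strings, so every non-string config value was being coerced and could no longer round-trip with its original type. A short sketch of the lossiness, using hypothetical keys:

```python
import os

# os.environ stores strings only, so the removed loop had to coerce values:
os.environ["DEBUG"] = str(True)          # bool True becomes the string "True"
assert os.environ["DEBUG"] == "True"     # the original bool type is gone

os.environ["SENTRY_DSN"] = ""            # None had to be flattened to ""
assert os.environ["SENTRY_DSN"] == ""    # indistinguishable from a real empty value

# Reading dify_config fields directly preserves the typed values instead.
```
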
@@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
     if language not in languages:
         language = "en-US"
 
-    name = name.strip()
+    # Validates name encoding for non-Latin characters.
+    name = name.strip().encode("utf-8").decode("utf-8") if name else None
 
     # generate random password
     new_password = secrets.token_urlsafe(16)

@@ -433,6 +433,11 @@ class WorkflowConfig(BaseSettings):
         default=5,
     )
 
+    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
+        description="Maximum allowed depth for nested parallel executions",
+        default=3,
+    )
+
     MAX_VARIABLE_SIZE: PositiveInt = Field(
         description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
         default=200 * 1024,

@@ -485,6 +490,11 @@ class AuthConfig(BaseSettings):
         default=60,
     )
 
+    LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
+        description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
+        default=86400,
+    )
+
 
 class ModerationConfig(BaseSettings):
     """

@@ -660,14 +670,9 @@ class IndexingConfig(BaseSettings):
     )
 
 
-class VisionFormatConfig(BaseSettings):
-    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
-        default="base64",
-    )
-
-    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
+class MultiModalTransferConfig(BaseSettings):
+    MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )
 

@@ -773,13 +778,13 @@ class FeatureConfig(
     FileAccessConfig,
     FileUploadConfig,
     HttpConfig,
-    VisionFormatConfig,
     InnerAPIConfig,
     IndexingConfig,
     LoggingConfig,
     MailConfig,
     ModelLoadBalanceConfig,
     ModerationConfig,
+    MultiModalTransferConfig,
     PositionConfig,
     RagEtlConfig,
     SecurityConfig,

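These config hunks all follow the same pydantic-settings pattern: declare a typed field with a default, and the matching environment variable overrides it automatically. A self-contained sketch (class name is illustrative, not dify's):

```python
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings

class WorkflowConfigSketch(BaseSettings):
    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
        description="Maximum allowed depth for nested parallel executions",
        default=3,
    )

# pydantic-settings reads the field from the environment automatically, e.g.
#   WORKFLOW_PARALLEL_DEPTH_LIMIT=5 python app.py   ->   limit == 5
print(WorkflowConfigSketch().WORKFLOW_PARALLEL_DEPTH_LIMIT)  # 3 without an override
```

PositiveInt also gives validation for free: a zero or negative override fails at startup rather than deep inside the workflow engine.
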
@@ -1,51 +1,9 @@
-from enum import StrEnum
-from typing import Literal
-
 from pydantic import Field
 from pydantic_settings import BaseSettings
 
 
-class OpenDALScheme(StrEnum):
-    FS = "fs"
-    S3 = "s3"
-
-
 class OpenDALStorageConfig(BaseSettings):
-    STORAGE_OPENDAL_SCHEME: str = Field(
-        default=OpenDALScheme.FS.value,
+    OPENDAL_SCHEME: str = Field(
+        default="fs",
         description="OpenDAL scheme.",
     )
-    # FS
-    OPENDAL_FS_ROOT: str = Field(
-        default="storage",
-        description="Root path for local storage.",
-    )
-    # S3
-    OPENDAL_S3_ROOT: str = Field(
-        default="/",
-        description="Root path for S3 storage.",
-    )
-    OPENDAL_S3_BUCKET: str = Field(
-        default="",
-        description="S3 bucket name.",
-    )
-    OPENDAL_S3_ENDPOINT: str = Field(
-        default="https://s3.amazonaws.com",
-        description="S3 endpoint URL.",
-    )
-    OPENDAL_S3_ACCESS_KEY_ID: str = Field(
-        default="",
-        description="S3 access key ID.",
-    )
-    OPENDAL_S3_SECRET_ACCESS_KEY: str = Field(
-        default="",
-        description="S3 secret access key.",
-    )
-    OPENDAL_S3_REGION: str = Field(
-        default="",
-        description="S3 region.",
-    )
-    OPENDAL_S3_SERVER_SIDE_ENCRYPTION: Literal["aws:kms", ""] = Field(
-        default="",
-        description="S3 server-side encryption.",
-    )

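With the per-scheme fields gone, the typed config only carries the scheme name and the remaining options flow in as operator kwargs. If the `opendal` Python binding is installed, pairing it with the slimmed-down config might look like the sketch below; this is based on the binding's documented `Operator(scheme, **kwargs)` constructor and is not dify's actual wiring:

```python
import opendal

scheme = "fs"  # from OPENDAL_SCHEME
# kwargs would come from OPENDAL_FS_* / OPENDAL_S3_* environment variables.
op = opendal.Operator(scheme, root="storage")

op.write("hello.txt", b"hello")
print(op.read("hello.txt"))  # b"hello"
```
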
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.14.0",
+        default="0.14.1",
     )
 
     COMMIT_SHA: str = Field(

@@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException
 class FilenameNotExistsError(HTTPException):
     code = 400
     description = "The specified filename does not exist."
+
+
+class RemoteFileUploadError(HTTPException):
+    code = 400
+    description = "Error uploading remote file."

@@ -31,7 +31,7 @@ def admin_required(view):
         if auth_scheme != "bearer":
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
 
-        if dify_config.ADMIN_API_KEY != auth_token:
+        if auth_token != dify_config.ADMIN_API_KEY:
             raise Unauthorized("API key is invalid.")
 
         return view(*args, **kwargs)

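Flipping the operand order here is purely cosmetic. A hardening step that is *not* part of this change, but is worth knowing when comparing secrets, is a constant-time comparison, which avoids leaking information through response timing:

```python
import hmac

def api_key_matches(provided: str, expected: str) -> bool:
    # hmac.compare_digest runs in time independent of where the strings
    # first differ, unlike a plain `!=` on secret values.
    return hmac.compare_digest(provided.encode(), expected.encode())
```
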
@@ -65,7 +65,7 @@ class ModelConfigResource(Resource):
                     provider_type=agent_tool_entity.provider_type,
                     identity_id=f"AGENT.{app_model.id}",
                 )
-            except Exception as e:
+            except Exception:
                 continue
 
             # get decrypted parameters

@@ -97,7 +97,7 @@ class ModelConfigResource(Resource):
                     app_id=app_model.id,
                     agent_tool=agent_tool_entity,
                 )
-            except Exception as e:
+            except Exception:
                 continue
 
             manager = ToolParameterConfigurationManager(

@@ -1,4 +1,5 @@
 from flask_restful import Resource, reqparse
+from werkzeug.exceptions import BadRequest
 
 from controllers.console import api
 from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist

@@ -26,7 +27,7 @@ class TraceAppConfigApi(Resource):
                 return {"has_not_configured": True}
             return trace_config
         except Exception as e:
-            raise e
+            raise BadRequest(str(e))
 
     @setup_required
     @login_required

@@ -48,7 +49,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigCheckError()
             return result
         except Exception as e:
-            raise e
+            raise BadRequest(str(e))
 
     @setup_required
     @login_required

@@ -68,7 +69,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigNotExist()
             return {"result": "success"}
         except Exception as e:
-            raise e
+            raise BadRequest(str(e))
 
     @setup_required
     @login_required

@@ -85,7 +86,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigNotExist()
             return {"result": "success"}
         except Exception as e:
-            raise e
+            raise BadRequest(str(e))
 
 
 api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")

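All four hunks repeat the same conversion of unexpected errors into a 400 response instead of re-raising them as 500s. The repetition could be factored into a decorator; the sketch below is a hypothetical refactoring, not code from this commit:

```python
from functools import wraps
from werkzeug.exceptions import BadRequest, HTTPException

def bad_request_on_error(fn):
    """Convert unexpected exceptions into a 400 instead of a generic 500."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except HTTPException:
            raise  # keep deliberate HTTP errors (404, 403, ...) untouched
        except Exception as e:
            raise BadRequest(str(e))
    return wrapper
```
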
@@ -6,6 +6,7 @@ from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
 from controllers.console.app.wraps import get_app_model

@@ -426,7 +427,21 @@ class ConvertToWorkflowApi(Resource):
         }
 
 
+class WorkflowConfigApi(Resource):
+    """Resource for workflow configuration."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def get(self, app_model: App):
+        return {
+            "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
+        }
+
+
 api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
+api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
 api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
 api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")

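The new endpoint exposes the server-side parallel-depth limit so the workflow editor can enforce it client-side. A hypothetical client call, assuming a self-hosted console API on localhost:5001 and a valid console access token (both assumptions, not from the diff):

```python
import requests

app_id = "00000000-0000-0000-0000-000000000000"  # placeholder app UUID
resp = requests.get(
    f"http://localhost:5001/console/api/apps/{app_id}/workflows/draft/config",
    headers={"Authorization": "Bearer <console-access-token>"},
)
print(resp.json())  # e.g. {"parallel_depth_limit": 3}
```
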
@@ -5,8 +5,7 @@ from typing import Optional, Union
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
-from models import App
-from models.model import AppMode
+from models import App, AppMode
 
 
 def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):

@@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
                 if document.indexing_status == "completed":
                     raise DocumentAlreadyFinishedError()
                 retry_documents.append(document)
-            except Exception as e:
+            except Exception:
                 logging.exception(f"Failed to retry document, document id: {document_id}")
                 continue
         # retry document

@@ -1,12 +1,14 @@
 from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 
 from controllers.console import api
 from controllers.console.explore.error import NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
 from models.model import AppMode

@@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
         pinned = True if args["pinned"] == "true" else False
 
         try:
-            return WebConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=current_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.EXPLORE,
-                pinned=pinned,
-            )
+            with Session(db.engine) as session:
+                return WebConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=current_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.EXPLORE,
+                    pinned=pinned,
+                )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
 

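This hunk (and its siblings in the service API and web controllers below) moves pagination from the request-scoped `db.session` to an explicitly created, caller-owned session. A minimal standalone sketch of the pattern, using an in-memory SQLite engine as a stand-in for `db.engine`:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # stand-in for db.engine

# The caller owns the unit of work: the session is opened for exactly this
# query and closed deterministically when the block exits, instead of
# depending on the framework's request-scoped session lifecycle.
with Session(engine) as session:
    rows = session.execute(text("SELECT 1")).all()
    print(rows)  # [(1,)]
```

Passing `session=` into the service also makes the service testable with any session the test provides.
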
@@ -70,7 +70,7 @@ class MessageFeedbackApi(InstalledAppResource):
         args = parser.parse_args()
 
         try:
-            MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 

@@ -4,6 +4,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse
 from constants.languages import languages
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required
+from libs.helper import AppIconUrlField
 from libs.login import login_required
 from services.recommended_app_service import RecommendedAppService
 

@@ -12,6 +13,8 @@ app_fields = {
     "name": fields.String,
     "mode": fields.String,
     "icon": fields.String,
+    "icon_type": fields.String,
+    "icon_url": AppIconUrlField,
     "icon_background": fields.String,
 }
 

@@ -1,6 +1,7 @@
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with
+from werkzeug.exceptions import Forbidden
 
 import services
 from configs import dify_config

@@ -58,6 +59,9 @@ class FileApi(Resource):
         if not file.filename:
             raise FilenameNotExistsError
 
+        if source == "datasets" and not current_user.is_dataset_editor:
+            raise Forbidden()
+
         if source not in ("datasets", None):
             source = None
 

@@ -7,6 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse
 
 import services
 from controllers.common import helpers
+from controllers.common.errors import RemoteFileUploadError
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
 from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields

@@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource):
 
         url = args["url"]
 
-        resp = ssrf_proxy.head(url=url)
-        if resp.status_code != httpx.codes.OK:
-            resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
-            resp.raise_for_status()
+        try:
+            resp = ssrf_proxy.head(url=url)
+            if resp.status_code != httpx.codes.OK:
+                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
+            if resp.status_code != httpx.codes.OK:
+                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
+        except httpx.RequestError as e:
+            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
 
         file_info = helpers.guess_file_info_from_response(resp)
 

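The probing strategy here is: try a cheap HEAD first, fall back to GET for servers that reject HEAD, and wrap both bad statuses and transport failures in one domain error. A self-contained sketch of the same pattern using plain httpx (dify routes these calls through its `ssrf_proxy` helper instead, which this sketch omits):

```python
import httpx

class RemoteFileUploadError(Exception):
    """Stand-in for the controllers.common.errors version."""

def probe_remote_file(url: str) -> httpx.Response:
    try:
        resp = httpx.head(url, timeout=3)
        if resp.status_code != httpx.codes.OK:
            # Some servers 405/403 on HEAD; retry with a real GET.
            resp = httpx.get(url, timeout=3, follow_redirects=True)
        if resp.status_code != httpx.codes.OK:
            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
        return resp
    except httpx.RequestError as e:
        # DNS failures, timeouts, refused connections, ...
        raise RemoteFileUploadError(f"Failed to fetch file from {url}: {e}")
```
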
@@ -3,12 +3,14 @@ import io
 from flask import send_file
 from flask_login import current_user
 from flask_restful import Resource, reqparse
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
 from configs import dify_config
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
+from extensions.ext_database import db
 from libs.helper import alphanumeric, uuid_value
 from libs.login import login_required
 from services.tools.api_tools_manage_service import ApiToolManageService

@@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):
 
         args = parser.parse_args()
 
-        return BuiltinToolManageService.update_builtin_tool_provider(
-            user_id,
-            tenant_id,
-            provider,
-            args["credentials"],
-        )
+        with Session(db.engine) as session:
+            result = BuiltinToolManageService.update_builtin_tool_provider(
+                session=session,
+                user_id=user_id,
+                tenant_id=tenant_id,
+                provider_name=provider,
+                credentials=args["credentials"],
+            )
+            session.commit()
+        return result
 
 
 class ToolBuiltinProviderGetCredentialsApi(Resource):

@@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider):
-        user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
         return BuiltinToolManageService.get_builtin_tool_provider_credentials(
-            user_id,
-            tenant_id,
-            provider,
+            tenant_id=tenant_id,
+            provider_name=provider,
         )
 
 

@@ -1,5 +1,6 @@
 from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 
 import services

@@ -7,6 +8,7 @@ from controllers.service_api import api
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import (
     conversation_delete_fields,
     conversation_infinite_scroll_pagination_fields,

@@ -39,14 +41,16 @@ class ConversationApi(Resource):
         args = parser.parse_args()
 
         try:
-            return ConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=end_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.SERVICE_API,
-                sort_by=args["sort_by"],
-            )
+            with Session(db.engine) as session:
+                return ConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=end_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.SERVICE_API,
+                    sort_by=args["sort_by"],
+                )
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
 

@@ -104,10 +104,11 @@ class MessageFeedbackApi(Resource):
 
         parser = reqparse.RequestParser()
         parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
+        parser.add_argument("content", type=str, location="json")
         args = parser.parse_args()
 
         try:
-            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 

@@ -1,11 +1,13 @@
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 
 from controllers.web import api
 from controllers.web.error import NotChatAppError
 from controllers.web.wraps import WebApiResource
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
 from models.model import AppMode

@@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
         pinned = True if args["pinned"] == "true" else False
 
         try:
-            return WebConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=end_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.WEB_APP,
-                pinned=pinned,
-                sort_by=args["sort_by"],
-            )
+            with Session(db.engine) as session:
+                return WebConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=end_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.WEB_APP,
+                    pinned=pinned,
+                    sort_by=args["sort_by"],
+                )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
 

@@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource):
         args = parser.parse_args()
 
         try:
-            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
 

@@ -5,6 +5,7 @@ from flask_restful import marshal_with, reqparse
 
 import services
 from controllers.common import helpers
+from controllers.common.errors import RemoteFileUploadError
 from controllers.web.wraps import WebApiResource
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy

@@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource):
 
         url = args["url"]
 
-        resp = ssrf_proxy.head(url=url)
-        if resp.status_code != httpx.codes.OK:
-            resp = ssrf_proxy.get(url=url, timeout=3)
-            resp.raise_for_status()
+        try:
+            resp = ssrf_proxy.head(url=url)
+            if resp.status_code != httpx.codes.OK:
+                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
+            if resp.status_code != httpx.codes.OK:
+                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
+        except httpx.RequestError as e:
+            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
 
         file_info = helpers.guess_file_info_from_response(resp)
 

@@ -4,14 +4,17 @@ import logging
 import queue
 import re
 import threading
+from collections.abc import Iterable
 
 from core.app.entities.queue_entities import (
+    MessageQueueMessage,
     QueueAgentMessageEvent,
     QueueLLMChunkEvent,
     QueueNodeSucceededEvent,
     QueueTextChunkEvent,
+    WorkflowQueueMessage,
 )
-from core.model_manager import ModelManager
+from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 
 

@@ -21,7 +24,7 @@ class AudioTrunk:
         self.status = status
 
 
-def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
+def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
     if not text_content or text_content.isspace():
         return
     return model_instance.invoke_tts(

@@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
     )
 
 
-def _process_future(future_queue, audio_queue):
+def _process_future(
+    future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
+    audio_queue: queue.Queue[AudioTrunk],
+):
     while True:
         try:
             future = future_queue.get()
             if future is None:
                 break
-            for audio in future.result():
+            invoke_result = future.result()
+            if not invoke_result:
+                continue
+            for audio in invoke_result:
                 audio_base64 = base64.b64encode(bytes(audio))
                 audio_queue.put(AudioTrunk("responding", audio=audio_base64))
         except Exception as e:

@@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
         self.msg_text = ""
-        self._audio_queue = queue.Queue()
-        self._msg_queue = queue.Queue()
+        self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
+        self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
         self.match = re.compile(r"[。.!?]")
         self.model_manager = ModelManager()
         self.model_instance = self.model_manager.get_default_model_instance(

@@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
         self._runtime_thread = threading.Thread(target=self._runtime).start()
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
 
-    def publish(self, message):
-        try:
-            self._msg_queue.put(message)
-        except Exception as e:
-            self.logger.warning(e)
+    def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
+        self._msg_queue.put(message)
 
     def _runtime(self):
-        future_queue = queue.Queue()
+        future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
         threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
         while True:
             try:

@@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
                 break
         future_queue.put(None)
 
-    def check_and_get_audio(self) -> AudioTrunk | None:
+    def check_and_get_audio(self):
         try:
             if self._last_audio_event and self._last_audio_event.status == "finish":
                 if self.executor:

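A recurring theme in this file is parameterizing `queue.Queue` so static checkers know what may flow through it, with `None` doubling as the shutdown sentinel. A small self-contained sketch of the idiom:

```python
import queue
from concurrent.futures import Future

# The annotation documents the contract: each item is either a pending
# Future producing bytes, or None as the "stop consuming" sentinel.
future_queue: "queue.Queue[Future[bytes] | None]" = queue.Queue()

future_queue.put(None)      # OK: the sentinel
item = future_queue.get()
if item is None:
    print("consumer shutting down")
```

At runtime the annotation costs nothing; the benefit is that a type checker will flag a `put()` of the wrong payload type before it deadlocks a consumer.
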
@@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,

@@ -179,7 +180,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             else:
                 continue
 
-        raise Exception("Queue listening stopped unexpectedly.")
+        raise ValueError("queue listening stopped unexpectedly.")
 
     def _to_stream_response(
         self, generator: Generator[StreamResponse, None, None]

@@ -196,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 stream_response=stream_response,
             )
 
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
 

@@ -221,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:

@@ -290,9 +291,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield self._workflow_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                if not workflow_run:
+                    raise ValueError("workflow run not initialized.")
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
+                if response:
+                    yield response
             elif isinstance(event, QueueNodeStartedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
 

@@ -330,47 +349,48 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 if response:
                     yield response
 
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_parallel_branch_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_parallel_branch_finished_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationStartEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_iteration_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationNextEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_iteration_next_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationCompletedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_iteration_completed_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueWorkflowSucceededEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 workflow_run = self._handle_workflow_run_success(
                     workflow_run=workflow_run,

@@ -389,10 +409,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")
 
                 workflow_run = self._handle_workflow_run_partial_success(
                     workflow_run=workflow_run,

@@ -412,10 +432,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")
 
                 workflow_run = self._handle_workflow_run_failed(
                     workflow_run=workflow_run,

@@ -494,7 +514,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 # only publish tts message at text chunk streaming
                 if tts_publisher:
-                    tts_publisher.publish(message=queue_message)
+                    tts_publisher.publish(queue_message)
 
                 self._task_state.answer += delta_text
                 yield self._message_to_stream_response(

@@ -505,7 +525,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             yield self._message_replace_to_stream_response(answer=event.text)
         elif isinstance(event, QueueAdvancedChatMessageEndEvent):
             if not graph_runtime_state:
-                raise Exception("Graph runtime state not initialized.")
+                raise ValueError("graph runtime state not initialized.")
 
             output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
             if output_moderation_answer:

@@ -1,7 +1,6 @@
 import queue
-import time
 from abc import abstractmethod
 from collections.abc import Generator
 from enum import Enum
 from typing import Any
 

@@ -11,9 +10,11 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
+    MessageQueueMessage,
     QueueErrorEvent,
     QueuePingEvent,
     QueueStopEvent,
+    WorkflowQueueMessage,
 )
 from extensions.ext_redis import redis_client
 

@@ -37,11 +38,11 @@ class AppQueueManager:
             AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
         )
 
-        q = queue.Queue()
+        q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
 
         self._q = q
 
-    def listen(self) -> Generator:
+    def listen(self):
         """
         Listen to queue
        :return:

@@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,

@@ -154,7 +155,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             else:
                 continue
 
-        raise Exception("Queue listening stopped unexpectedly.")
+        raise ValueError("queue listening stopped unexpectedly.")
 
     def _to_stream_response(
         self, generator: Generator[StreamResponse, None, None]

@@ -170,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
         yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
 
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
 

@@ -195,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:

@@ -217,7 +218,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     break
                 else:
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
-            except Exception as e:
+            except Exception:
                 logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
                 break
         if tts_publisher:

@@ -253,9 +254,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 yield self._workflow_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                if not workflow_run:
+                    raise ValueError("workflow run not initialized.")
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
+                if response:
+                    yield response
             elif isinstance(event, QueueNodeStartedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
 

@@ -286,50 +305,50 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )
 
                 if node_failed_response:
                     yield node_failed_response
 
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_parallel_branch_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_parallel_branch_finished_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationStartEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_iteration_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationNextEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_iteration_next_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationCompletedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 yield self._workflow_iteration_completed_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueWorkflowSucceededEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")
 
                 workflow_run = self._handle_workflow_run_success(
                     workflow_run=workflow_run,

@@ -349,10 +368,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 )
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")
 
                 workflow_run = self._handle_workflow_run_partial_success(
                     workflow_run=workflow_run,

@@ -373,10 +392,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 )
             elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")
 
                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")
                 workflow_run = self._handle_workflow_run_failed(
                     workflow_run=workflow_run,
                     start_at=graph_runtime_state.start_at,

@@ -404,7 +423,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
             # only publish tts message at text chunk streaming
             if tts_publisher:
-                tts_publisher.publish(message=queue_message)
+                tts_publisher.publish(queue_message)
 
             self._task_state.answer += delta_text
             yield self._text_chunk_to_stream_response(

@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,

@@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,

@@ -186,6 +188,41 @@ class WorkflowBasedAppRunner(AppRunner):
             )
         elif isinstance(event, GraphRunFailedEvent):
             self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
+        elif isinstance(event, NodeRunRetryEvent):
+            node_run_result = event.route_node_state.node_run_result
+            if node_run_result:
+                inputs = node_run_result.inputs
+                process_data = node_run_result.process_data
+                outputs = node_run_result.outputs
+                execution_metadata = node_run_result.metadata
+            else:
+                inputs = {}
+                process_data = {}
+                outputs = {}
+                execution_metadata = {}
+            self._publish_event(
+                QueueNodeRetryEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.start_at,
+                    node_run_index=event.route_node_state.index,
+                    predecessor_node_id=event.predecessor_node_id,
+                    in_iteration_id=event.in_iteration_id,
+                    parallel_mode_run_id=event.parallel_mode_run_id,
+                    inputs=inputs,
+                    process_data=process_data,
+                    outputs=outputs,
+                    error=event.error,
+                    execution_metadata=execution_metadata,
+                    retry_index=event.retry_index,
+                )
+            )
         elif isinstance(event, NodeRunStartedEvent):
             self._publish_event(
                 QueueNodeStartedEvent(

@@ -205,6 +242,17 @@ class WorkflowBasedAppRunner(AppRunner):
                 )
             )
         elif isinstance(event, NodeRunSucceededEvent):
+            node_run_result = event.route_node_state.node_run_result
+            if node_run_result:
+                inputs = node_run_result.inputs
+                process_data = node_run_result.process_data
+                outputs = node_run_result.outputs
+                execution_metadata = node_run_result.metadata
+            else:
+                inputs = {}
+                process_data = {}
+                outputs = {}
+                execution_metadata = {}
             self._publish_event(
                 QueueNodeSucceededEvent(
                     node_execution_id=event.id,

@@ -216,18 +264,10 @@ class WorkflowBasedAppRunner(AppRunner):
                     parent_parallel_id=event.parent_parallel_id,
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
+                    inputs=inputs,
+                    process_data=process_data,
+                    outputs=outputs,
+                    execution_metadata=execution_metadata,
                     in_iteration_id=event.in_iteration_id,
                 )
             )

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from datetime import datetime
 from enum import Enum, StrEnum
 from typing import Any, Optional

@@ -43,6 +44,7 @@ class QueueEvent(StrEnum):
     ERROR = "error"
     PING = "ping"
     STOP = "stop"
+    RETRY = "retry"
 
 
 class AppQueueEvent(BaseModel):

@@ -84,9 +86,9 @@ class QueueIterationStartEvent(AppQueueEvent):
     start_at: datetime
 
     node_run_index: int
-    inputs: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
     predecessor_node_id: Optional[str] = None
-    metadata: Optional[dict[str, Any]] = None
+    metadata: Optional[Mapping[str, Any]] = None
 
 
 class QueueIterationNextEvent(AppQueueEvent):

@@ -138,9 +140,9 @@ class QueueIterationCompletedEvent(AppQueueEvent):
     start_at: datetime
 
     node_run_index: int
-    inputs: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    metadata: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    metadata: Optional[Mapping[str, Any]] = None
     steps: int = 0
 
     error: Optional[str] = None

@@ -303,9 +305,9 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     start_at: datetime
 
-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
     execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
 
     error: Optional[str] = None

@@ -313,6 +315,20 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     iteration_duration_map: Optional[dict[str, float]] = None
 
 
+class QueueNodeRetryEvent(QueueNodeStartedEvent):
+    """QueueNodeRetryEvent entity"""
+
+    event: QueueEvent = QueueEvent.RETRY
+
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
+
+    error: str
+    retry_index: int  # retry index
+
+
 class QueueNodeInIterationFailedEvent(AppQueueEvent):
     """
     QueueNodeInIterationFailedEvent entity

@@ -336,10 +352,10 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     start_at: datetime
 
-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
 
     error: str
 

@@ -367,10 +383,10 @@ class QueueNodeExceptionEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     start_at: datetime
 
-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
 
     error: str
 

@@ -398,10 +414,10 @@ class QueueNodeFailedEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     start_at: datetime
 
-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
 
     error: str
 

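The wholesale switch from `dict[str, Any]` to `Mapping[str, Any]` signals read-only intent: event payloads may be shared between producers and consumers, and `Mapping` has no `__setitem__`, so type checkers will reject in-place mutation. A small sketch of the effect on a pydantic model:

```python
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel

class Event(BaseModel):
    # Mapping is read-only at the type level; consumers cannot be assumed
    # to mutate payloads that other parts of the pipeline also hold.
    inputs: Optional[Mapping[str, Any]] = None

e = Event(inputs={"k": "v"})
print(e.inputs["k"])        # reading is fine
# e.inputs["k"] = "x"       # rejected by type checkers: Mapping has no __setitem__
```

Plain dicts still validate against the field, so callers need no changes; only would-be mutators are flagged.
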
@@ -52,6 +52,7 @@ class StreamEvent(Enum):
     WORKFLOW_FINISHED = "workflow_finished"
     NODE_STARTED = "node_started"
     NODE_FINISHED = "node_finished"
+    NODE_RETRY = "node_retry"
     PARALLEL_BRANCH_STARTED = "parallel_branch_started"
     PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
     ITERATION_STARTED = "iteration_started"

@@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
         }
 
 
+class NodeRetryStreamResponse(StreamResponse):
+    """
+    NodeFinishStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+
+        id: str
+        node_id: str
+        node_type: str
+        title: str
+        index: int
+        predecessor_node_id: Optional[str] = None
+        inputs: Optional[dict] = None
+        process_data: Optional[dict] = None
+        outputs: Optional[dict] = None
+        status: str
+        error: Optional[str] = None
+        elapsed_time: float
+        execution_metadata: Optional[dict] = None
+        created_at: int
+        finished_at: int
+        files: Optional[Sequence[Mapping[str, Any]]] = []
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        retry_index: int = 0
+
+    event: StreamEvent = StreamEvent.NODE_RETRY
+    workflow_run_id: str
+    data: Data
+
+    def to_ignore_detail_dict(self):
+        return {
+            "event": self.event.value,
+            "task_id": self.task_id,
+            "workflow_run_id": self.workflow_run_id,
+            "data": {
+                "id": self.data.id,
+                "node_id": self.data.node_id,
+                "node_type": self.data.node_type,
+                "title": self.data.title,
+                "index": self.data.index,
+                "predecessor_node_id": self.data.predecessor_node_id,
+                "inputs": None,
+                "process_data": None,
+                "outputs": None,
+                "status": self.data.status,
+                "error": None,
+                "elapsed_time": self.data.elapsed_time,
+                "execution_metadata": None,
+                "created_at": self.data.created_at,
+                "finished_at": self.data.finished_at,
+                "files": [],
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
+                "retry_index": self.data.retry_index,
+            },
+        }
+
+
 class ParallelBranchStartStreamResponse(StreamResponse):
     """
     ParallelBranchStartStreamResponse entity

@@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             stream_response=stream_response,
         )
 
-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if publisher is None:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
     IterationNodeNextStreamResponse,
     IterationNodeStartStreamResponse,
     NodeFinishStreamResponse,
+    NodeRetryStreamResponse,
     NodeStartStreamResponse,
     ParallelBranchFinishedStreamResponse,
     ParallelBranchStartStreamResponse,
@@ -271,9 +273,9 @@ class WorkflowCycleManage:
 
         db.session.close()
 
-        with Session(db.engine, expire_on_commit=False) as session:
-            session.add(workflow_run)
-            session.refresh(workflow_run)
+        # with Session(db.engine, expire_on_commit=False) as session:
+        #     session.add(workflow_run)
+        #     session.refresh(workflow_run)
 
         if trace_manager:
             trace_manager.add_trace_task(
@@ -423,6 +425,59 @@ class WorkflowCycleManage:
 
         return workflow_node_execution
 
+    def _handle_workflow_node_execution_retried(
+        self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+    ) -> WorkflowNodeExecution:
+        """
+        Handle workflow node execution retried
+        :param event: queue node retry event
+        :return:
+        """
+        created_at = event.start_at
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
+        elapsed_time = (finished_at - created_at).total_seconds()
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
+        origin_metadata = {
+            NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+        }
+        merged_metadata = (
+            {**jsonable_encoder(event.execution_metadata), **origin_metadata}
+            if event.execution_metadata is not None
+            else origin_metadata
+        )
+        execution_metadata = json.dumps(merged_metadata)
+
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.created_at = created_at
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time
+        workflow_node_execution.error = event.error
+        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
+        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
+        workflow_node_execution.execution_metadata = execution_metadata
+        workflow_node_execution.index = event.node_run_index
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+        db.session.refresh(workflow_node_execution)
+
+        return workflow_node_execution
+
     #################################################
     #             to stream responses              #
     #################################################
@@ -457,6 +512,12 @@ class WorkflowCycleManage:
         :param workflow_run: workflow run
         :return:
         """
+        # Attach WorkflowRun to an active session so "created_by_role" can be accessed.
+        workflow_run = db.session.merge(workflow_run)
+
+        # Refresh to ensure any expired attributes are fully loaded
+        db.session.refresh(workflow_run)
+
         created_by = None
         if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
             created_by_account = workflow_run.created_by_account
@@ -587,6 +648,51 @@ class WorkflowCycleManage:
             ),
         )
 
+    def _workflow_node_retry_to_stream_response(
+        self,
+        event: QueueNodeRetryEvent,
+        task_id: str,
+        workflow_node_execution: WorkflowNodeExecution,
+    ) -> Optional[NodeRetryStreamResponse]:
+        """
+        Workflow node retry to stream response.
+        :param event: queue node retry event
+        :param task_id: task id
+        :param workflow_node_execution: workflow node execution
+        :return:
+        """
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+            return None
+
+        return NodeRetryStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_node_execution.workflow_run_id,
+            data=NodeRetryStreamResponse.Data(
+                id=workflow_node_execution.id,
+                node_id=workflow_node_execution.node_id,
+                node_type=workflow_node_execution.node_type,
+                index=workflow_node_execution.index,
+                title=workflow_node_execution.title,
+                predecessor_node_id=workflow_node_execution.predecessor_node_id,
+                inputs=workflow_node_execution.inputs_dict,
+                process_data=workflow_node_execution.process_data_dict,
+                outputs=workflow_node_execution.outputs_dict,
+                status=workflow_node_execution.status,
+                error=workflow_node_execution.error,
+                elapsed_time=workflow_node_execution.elapsed_time,
+                execution_metadata=workflow_node_execution.execution_metadata_dict,
+                created_at=int(workflow_node_execution.created_at.timestamp()),
+                finished_at=int(workflow_node_execution.finished_at.timestamp()),
+                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                retry_index=event.retry_index,
+            ),
+        )
+
     def _workflow_parallel_branch_start_to_stream_response(
         self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
     ) -> ParallelBranchStartStreamResponse:
@@ -1,7 +1,7 @@
 from typing import Optional
 
 
-class LLMError(Exception):
+class LLMError(ValueError):
     """Base class for all LLM exceptions."""
 
     description: Optional[str] = None

@@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError):
     description = "Bad Request"
 
 
-class ProviderTokenNotInitError(Exception):
+class ProviderTokenNotInitError(ValueError):
     """
     Custom exception raised when the provider token is not initialized.
     """

@@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception):
         self.description = args[0] if args else self.description
 
 
-class QuotaExceededError(Exception):
+class QuotaExceededError(ValueError):
     """
     Custom exception raised when the quota for a provider has been exceeded.
     """

@@ -35,7 +35,7 @@ class QuotaExceededError(Exception):
     description = "Quota Exceeded"
 
 
-class AppInvokeQuotaExceededError(Exception):
+class AppInvokeQuotaExceededError(ValueError):
     """
     Custom exception raised when the quota for an app has been exceeded.
     """

@@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception):
     description = "App Invoke Quota Exceeded"
 
 
-class ModelCurrentlyNotSupportError(Exception):
+class ModelCurrentlyNotSupportError(ValueError):
     """
     Custom exception raised when the model is not supported.
     """

@@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception):
     description = "Model Currently Not Support"
 
 
-class InvokeRateLimitError(Exception):
+class InvokeRateLimitError(ValueError):
     """Raised when the Invoke returns rate limit error."""
 
     description = "Rate Limit Error"
@@ -1,15 +1,14 @@
 import base64
 
 from configs import dify_config
-from core.file import file_repository
 from core.helper import ssrf_proxy
 from core.model_runtime.entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
+    MultiModalPromptMessageContent,
     VideoPromptMessageContent,
 )
-from extensions.ext_database import db
 from extensions.ext_storage import storage
 
 from . import helpers
@@ -41,53 +40,42 @@ def to_prompt_message_content(
     /,
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
-):
-    match f.type:
-        case FileType.IMAGE:
-            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
-            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
-                data = _to_url(f)
-            else:
-                data = _to_base64_data_string(f)
-
-            return ImagePromptMessageContent(data=data, detail=image_detail_config)
-        case FileType.AUDIO:
-            encoded_string = _get_encoded_string(f)
-            if f.extension is None:
-                raise ValueError("Missing file extension")
-            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
-        case FileType.VIDEO:
-            if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
-                data = _to_url(f)
-            else:
-                data = _to_base64_data_string(f)
-            if f.extension is None:
-                raise ValueError("Missing file extension")
-            return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
-        case FileType.DOCUMENT:
-            data = _get_encoded_string(f)
-            if f.mime_type is None:
-                raise ValueError("Missing file mime_type")
-            return DocumentPromptMessageContent(
-                encode_format="base64",
-                mime_type=f.mime_type,
-                data=data,
-            )
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")
+) -> MultiModalPromptMessageContent:
+    if f.extension is None:
+        raise ValueError("Missing file extension")
+    if f.mime_type is None:
+        raise ValueError("Missing file mime_type")
+
+    params = {
+        "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
+        "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
+        "format": f.extension.removeprefix("."),
+        "mime_type": f.mime_type,
+    }
+    if f.type == FileType.IMAGE:
+        params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
+
+    prompt_class_map = {
+        FileType.IMAGE: ImagePromptMessageContent,
+        FileType.AUDIO: AudioPromptMessageContent,
+        FileType.VIDEO: VideoPromptMessageContent,
+        FileType.DOCUMENT: DocumentPromptMessageContent,
+    }
+
+    try:
+        return prompt_class_map[f.type](**params)
+    except KeyError:
+        raise ValueError(f"file type {f.type} is not supported")
 
 
 def download(f: File, /):
-    if f.transfer_method == FileTransferMethod.TOOL_FILE:
-        tool_file = file_repository.get_tool_file(session=db.session(), file=f)
-        return _download_file_content(tool_file.file_key)
-    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
-        upload_file = file_repository.get_upload_file(session=db.session(), file=f)
-        return _download_file_content(upload_file.key)
-    # remote file
-    response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
-    response.raise_for_status()
-    return response.content
+    if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
+        return _download_file_content(f._storage_key)
+    elif f.transfer_method == FileTransferMethod.REMOTE_URL:
+        response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
+        response.raise_for_status()
+        return response.content
+    raise ValueError(f"unsupported transfer method: {f.transfer_method}")
 
 
 def _download_file_content(path: str, /):
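The refactor above collapses the per-type match statement into a shared parameter dict plus a class lookup table. A standalone sketch of the same dispatch pattern under assumed, simplified models (none of these names are the module's real API):

from enum import Enum

from pydantic import BaseModel


class Kind(str, Enum):
    IMAGE = "image"
    AUDIO = "audio"


class ImageContent(BaseModel):
    url: str = ""
    detail: str = "low"


class AudioContent(BaseModel):
    url: str = ""


_CLASS_MAP = {Kind.IMAGE: ImageContent, Kind.AUDIO: AudioContent}


def build_content(kind: Kind, **params):
    # one table replaces a chain of per-type branches; unknown kinds
    # fail with the same error shape as the diff's KeyError handler
    try:
        return _CLASS_MAP[kind](**params)
    except KeyError:
        raise ValueError(f"kind {kind} is not supported")


print(build_content(Kind.IMAGE, url="https://example.com/a.png"))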
@@ -118,21 +106,14 @@ def _get_encoded_string(f: File, /):
             response.raise_for_status()
             data = response.content
         case FileTransferMethod.LOCAL_FILE:
-            upload_file = file_repository.get_upload_file(session=db.session(), file=f)
-            data = _download_file_content(upload_file.key)
+            data = _download_file_content(f._storage_key)
         case FileTransferMethod.TOOL_FILE:
-            tool_file = file_repository.get_tool_file(session=db.session(), file=f)
-            data = _download_file_content(tool_file.file_key)
+            data = _download_file_content(f._storage_key)
 
     encoded_string = base64.b64encode(data).decode("utf-8")
     return encoded_string
 
 
-def _to_base64_data_string(f: File, /):
-    encoded_string = _get_encoded_string(f)
-    return f"data:{f.mime_type};base64,{encoded_string}"
-
-
 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:
@@ -1,32 +0,0 @@
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from models import ToolFile, UploadFile
-
-from .models import File
-
-
-def get_upload_file(*, session: Session, file: File):
-    if file.related_id is None:
-        raise ValueError("Missing file related_id")
-    stmt = select(UploadFile).filter(
-        UploadFile.id == file.related_id,
-        UploadFile.tenant_id == file.tenant_id,
-    )
-    record = session.scalar(stmt)
-    if not record:
-        raise ValueError(f"upload file {file.related_id} not found")
-    return record
-
-
-def get_tool_file(*, session: Session, file: File):
-    if file.related_id is None:
-        raise ValueError("Missing file related_id")
-    stmt = select(ToolFile).filter(
-        ToolFile.id == file.related_id,
-        ToolFile.tenant_id == file.tenant_id,
-    )
-    record = session.scalar(stmt)
-    if not record:
-        raise ValueError(f"tool file {file.related_id} not found")
-    return record
@@ -47,6 +47,38 @@ class File(BaseModel):
     mime_type: Optional[str] = None
     size: int = -1
 
+    # Those properties are private, should not be exposed to the outside.
+    _storage_key: str
+
+    def __init__(
+        self,
+        *,
+        id: Optional[str] = None,
+        tenant_id: str,
+        type: FileType,
+        transfer_method: FileTransferMethod,
+        remote_url: Optional[str] = None,
+        related_id: Optional[str] = None,
+        filename: Optional[str] = None,
+        extension: Optional[str] = None,
+        mime_type: Optional[str] = None,
+        size: int = -1,
+        storage_key: str,
+    ):
+        super().__init__(
+            id=id,
+            tenant_id=tenant_id,
+            type=type,
+            transfer_method=transfer_method,
+            remote_url=remote_url,
+            related_id=related_id,
+            filename=filename,
+            extension=extension,
+            mime_type=mime_type,
+            size=size,
+        )
+        self._storage_key = storage_key
+
     def to_dict(self) -> Mapping[str, str | int | None]:
         data = self.model_dump(mode="json")
         return {
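The File hunk above threads a storage key through a keyword-only constructor into a private attribute, keeping it out of the validated, serialized fields. A minimal sketch of the same pydantic private-attribute pattern (the Doc model is hypothetical):

from pydantic import BaseModel, PrivateAttr


class Doc(BaseModel):
    name: str
    # private state lives outside the model's validated fields
    _storage_key: str = PrivateAttr()

    def __init__(self, *, name: str, storage_key: str):
        super().__init__(name=name)
        self._storage_key = storage_key


doc = Doc(name="report.pdf", storage_key="tenant-1/report.pdf")
print(doc.model_dump())  # {'name': 'report.pdf'} -- the key is not serialized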
@@ -118,7 +118,7 @@ class CodeExecutor:
         return response.data.stdout or ""
 
     @classmethod
-    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
+    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
         """
         Execute code
         :param language: code language
@@ -25,7 +25,7 @@ class TemplateTransformer(ABC):
         return runner_script, preload_script
 
     @classmethod
-    def extract_result_str_from_response(cls, response: str) -> str:
+    def extract_result_str_from_response(cls, response: str):
        result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
        if not result:
            raise ValueError("Failed to parse result")
@@ -33,13 +33,21 @@ class TemplateTransformer(ABC):
         return result
 
     @classmethod
-    def transform_response(cls, response: str) -> dict:
+    def transform_response(cls, response: str) -> Mapping[str, Any]:
         """
         Transform response to dict
         :param response: response
         :return:
         """
-        return json.loads(cls.extract_result_str_from_response(response))
+        try:
+            result = json.loads(cls.extract_result_str_from_response(response))
+        except json.JSONDecodeError:
+            raise ValueError("failed to parse response")
+        if not isinstance(result, dict):
+            raise ValueError("result must be a dict")
+        if not all(isinstance(k, str) for k in result):
+            raise ValueError("result keys must be strings")
+        return result
 
     @classmethod
     @abstractmethod
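The new transform_response turns a bare json.loads into a validated Mapping: malformed JSON, non-object results, and non-string keys are each rejected explicitly. The same guard sequence in isolation (a sketch, not the class method itself):

import json
from collections.abc import Mapping
from typing import Any


def parse_result(payload: str) -> Mapping[str, Any]:
    # mirror the three checks added in the diff
    try:
        result = json.loads(payload)
    except json.JSONDecodeError:
        raise ValueError("failed to parse response")
    if not isinstance(result, dict):
        raise ValueError("result must be a dict")
    if not all(isinstance(k, str) for k in result):
        raise ValueError("result keys must be strings")
    return result


print(parse_result('{"answer": 42}'))  # {'answer': 42}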
@@ -1,6 +1,5 @@
 import base64
 
-from extensions.ext_database import db
 from libs import rsa
 
 
@@ -14,6 +13,7 @@ def obfuscated_token(token: str):
 
 def encrypt_token(tenant_id: str, token: str):
     from models.account import Tenant
+    from models.engine import db
 
     if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
         raise ValueError(f"Tenant with id {tenant_id} not found")
@@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5
 STATUS_FORCELIST = [429, 500, 502, 503, 504]
 
 
-class MaxRetriesExceededError(Exception):
+class MaxRetriesExceededError(ValueError):
     """Raised when the maximum number of retries is exceeded."""
 
     pass
@@ -65,14 +65,16 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
                 f"Received status code {response.status_code} for URL {url} which is in the force list")
 
         except httpx.RequestError as e:
-            logging.warning(
-                f"Request to URL {url} failed on attempt {retries + 1}: {e}")
+            logging.warning(f"Request to URL {url} failed on attempt {
+                            retries + 1}: {e}")
+            if max_retries == 0:
+                raise
 
         retries += 1
         if retries <= max_retries:
             time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
 
-    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
+    raise MaxRetriesExceededError(
+        f"Reached maximum retries ({max_retries}) for URL {url}")
 
 
 def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
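The retry loop above sleeps BACKOFF_FACTOR * 2**(retries - 1) between attempts and, with this change, re-raises immediately when max_retries == 0. The resulting schedule, computed in isolation with the constant from the hunk:

BACKOFF_FACTOR = 0.5


def backoff_delays(max_retries: int) -> list[float]:
    # attempt 1 -> 0.5s, attempt 2 -> 1.0s, attempt 3 -> 2.0s, ...
    return [BACKOFF_FACTOR * (2 ** (attempt - 1)) for attempt in range(1, max_retries + 1)]


print(backoff_delays(3))  # [0.5, 1.0, 2.0]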
@@ -1,2 +1,2 @@
-class OutputParserError(Exception):
+class OutputParserError(ValueError):
     pass
@@ -4,6 +4,7 @@ from .message_entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
+    MultiModalPromptMessageContent,
     PromptMessage,
     PromptMessageContent,
     PromptMessageContentType,
@@ -27,6 +28,7 @@ __all__ = [
     "LLMResultChunkDelta",
     "LLMUsage",
     "ModelPropertyKey",
+    "MultiModalPromptMessageContent",
     "PromptMessage",
     "PromptMessage",
     "PromptMessageContent",
@@ -1,9 +1,9 @@
 from abc import ABC
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Literal, Optional
+from typing import Optional
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, computed_field, field_validator
 
 
 class PromptMessageRole(Enum):
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
     """
 
     type: PromptMessageContentType
-    data: str
 
 
 class TextPromptMessageContent(PromptMessageContent):
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
     """
 
     type: PromptMessageContentType = PromptMessageContentType.TEXT
+    data: str
 
 
-class VideoPromptMessageContent(PromptMessageContent):
+class MultiModalPromptMessageContent(PromptMessageContent):
+    """
+    Model class for multi-modal prompt message content.
+    """
+
+    type: PromptMessageContentType
+    format: str = Field(default=..., description="the format of multi-modal file")
+    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
+    url: str = Field(default="", description="the url of multi-modal file")
+    mime_type: str = Field(default=..., description="the mime type of multi-modal file")
+
+    @computed_field(return_type=str)
+    @property
+    def data(self):
+        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
+
+
+class VideoPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.VIDEO
-    data: str = Field(..., description="Base64 encoded video data")
-    format: str = Field(..., description="Video format")
 
 
-class AudioPromptMessageContent(PromptMessageContent):
+class AudioPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.AUDIO
-    data: str = Field(..., description="Base64 encoded audio data")
-    format: str = Field(..., description="Audio format")
 
 
-class ImagePromptMessageContent(PromptMessageContent):
+class ImagePromptMessageContent(MultiModalPromptMessageContent):
     """
     Model class for image prompt message content.
     """
@@ -103,11 +116,8 @@ class ImagePromptMessageContent(PromptMessageContent):
     detail: DETAIL = DETAIL.LOW
 
 
-class DocumentPromptMessageContent(PromptMessageContent):
+class DocumentPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
-    encode_format: Literal["base64"]
-    mime_type: str
-    data: str
 
 
 class PromptMessage(ABC, BaseModel):
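data on MultiModalPromptMessageContent is now a computed field: it prefers a URL and otherwise synthesizes a data: URI from the base64 payload. A trimmed stand-in model showing that resolution order (not the full class):

from pydantic import BaseModel, computed_field


class Content(BaseModel):
    mime_type: str
    base64_data: str = ""
    url: str = ""

    @computed_field(return_type=str)
    @property
    def data(self) -> str:
        # URL wins when present; otherwise fall back to a data: URI
        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"


print(Content(mime_type="image/png", base64_data="iVBORw0=").data)
# data:image/png;base64,iVBORw0=
print(Content(mime_type="image/png", url="https://example.com/x.png").data)
# https://example.com/x.png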
@@ -1,7 +1,7 @@
 from typing import Optional
 
 
-class InvokeError(Exception):
+class InvokeError(ValueError):
     """Base class for all LLM exceptions."""
 
     description: Optional[str] = None
@@ -1,4 +1,4 @@
-class CredentialsValidateFailedError(Exception):
+class CredentialsValidateFailedError(ValueError):
     """
     Credentials validate failed error
     """
@@ -1,5 +1,4 @@
 import base64
-import io
 import json
 from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
@@ -18,7 +17,6 @@ from anthropic.types import (
 )
 from anthropic.types.beta.tools import ToolsBetaMessage
 from httpx import Timeout
-from PIL import Image
 
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities import (
@@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                     sub_messages.append(sub_message_dict)
                 elif message_content.type == PromptMessageContentType.IMAGE:
                     message_content = cast(ImagePromptMessageContent, message_content)
-                    if not message_content.data.startswith("data:"):
+                    if not message_content.base64_data:
                         # fetch image data from url
                         try:
-                            image_content = requests.get(message_content.data).content
-                            with Image.open(io.BytesIO(image_content)) as img:
-                                mime_type = f"image/{img.format.lower()}"
+                            image_content = requests.get(message_content.url).content
                             base64_data = base64.b64encode(image_content).decode("utf-8")
                         except Exception as ex:
                             raise ValueError(
                                 f"Failed to fetch image data from url {message_content.data}, {ex}"
                             )
                     else:
-                        data_split = message_content.data.split(";base64,")
-                        mime_type = data_split[0].replace("data:", "")
-                        base64_data = data_split[1]
+                        base64_data = message_content.base64_data
 
+                    mime_type = message_content.mime_type
                     if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                         raise ValueError(
                             f"Unsupported image type {mime_type}, "
@@ -534,9 +529,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                     sub_message_dict = {
                         "type": "document",
                         "source": {
-                            "type": message_content.encode_format,
+                            "type": "base64",
                             "media_type": message_content.mime_type,
-                            "data": message_content.data,
+                            "data": message_content.base64_data,
                         },
                     }
                     sub_messages.append(sub_message_dict)
@@ -819,6 +819,82 @@ LLM_BASE_MODELS = [
             ),
         ),
     ),
+    AzureBaseModel(
+        base_model_name="gpt-4o-2024-11-20",
+        entity=AIModelEntity(
+            model="fake-deployment-name",
+            label=I18nObject(
+                en_US="fake-deployment-name-label",
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+                ModelFeature.VISION,
+                ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="temperature",
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+                ),
+                ParameterRule(
+                    name="top_p",
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+                ),
+                ParameterRule(
+                    name="presence_penalty",
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+                ),
+                ParameterRule(
+                    name="frequency_penalty",
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=16384),
+                ParameterRule(
+                    name="seed",
+                    label=I18nObject(zh_Hans="种子", en_US="Seed"),
+                    type="int",
+                    help=AZURE_DEFAULT_PARAM_SEED_HELP,
+                    required=False,
+                    precision=2,
+                    min=0,
+                    max=1,
+                ),
+                ParameterRule(
+                    name="response_format",
+                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
+                    type="string",
+                    help=I18nObject(
+                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
+                    ),
+                    required=False,
+                    options=["text", "json_object", "json_schema"],
+                ),
+                ParameterRule(
+                    name="json_schema",
+                    label=I18nObject(en_US="JSON Schema"),
+                    type="text",
+                    help=I18nObject(
+                        zh_Hans="设置返回的json schema,llm将按照它返回",
+                        en_US="Set a response json schema will ensure LLM to adhere it.",
+                    ),
+                    required=False,
+                ),
+            ],
+            pricing=PriceConfig(
+                input=5.00,
+                output=15.00,
+                unit=0.000001,
+                currency="USD",
+            ),
+        ),
+    ),
     AzureBaseModel(
         base_model_name="gpt-4-turbo",
         entity=AIModelEntity(
@@ -86,6 +86,9 @@ model_credential_schema:
         - label:
             en_US: '2024-06-01'
           value: '2024-06-01'
+        - label:
+            en_US: '2024-10-21'
+          value: '2024-10-21'
     placeholder:
       zh_Hans: 在此选择您的 API 版本
       en_US: Select your API Version here
@@ -168,6 +171,12 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: llm
+        - label:
+            en_US: gpt-4o-2024-11-20
+          value: gpt-4o-2024-11-20
+          show_on:
+            - variable: __model_type
+              value: llm
         - label:
             en_US: gpt-4-turbo
           value: gpt-4-turbo
@@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding
 
         # calc usage
         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
@@ -1,11 +1,19 @@
+from collections.abc import Mapping
+
 import boto3
 from botocore.config import Config
 
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+
 
-def get_bedrock_client(service_name, credentials=None):
-    client_config = Config(region_name=credentials["aws_region"])
-    aws_access_key_id = credentials["aws_access_key_id"]
-    aws_secret_access_key = credentials["aws_secret_access_key"]
+def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
+    region_name = credentials.get("aws_region")
+    if not region_name:
+        raise InvokeBadRequestError("aws_region is required")
+    client_config = Config(region_name=region_name)
+    aws_access_key_id = credentials.get("aws_access_key_id")
+    aws_secret_access_key = credentials.get("aws_secret_access_key")
+
     if aws_access_key_id and aws_secret_access_key:
         # use aksk to call bedrock
         client = boto3.client(
@@ -62,7 +62,10 @@ class BedrockRerankModel(RerankModel):
             }
         )
         modelId = model
-        region = credentials["aws_region"]
+        region = credentials.get("aws_region")
+        # region is a required field
+        if not region:
+            raise InvokeBadRequestError("aws_region is required in credentials")
         model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
         rerankingConfiguration = {
             "type": "BEDROCK_RERANKING_MODEL",
@@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding
 
         # calc usage
         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
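Both embedding hunks (Azure OpenAI and Cohere) add the same guard: when a batch averages to the zero vector, normalization divides by zero and yields NaN, which would previously be stored silently. A minimal NumPy reproduction of the failure mode the check catches:

import numpy as np

average = np.zeros(3)  # a degenerate batch can average to the zero vector
with np.errstate(invalid="ignore"):
    embedding = (average / np.linalg.norm(average)).tolist()

if np.isnan(embedding).any():
    print("NaN embedding detected; the diff raises ValueError here")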
@@ -0,0 +1,93 @@
+model: InternVL2-8B
+label:
+  en_US: InternVL2-8B
+model_type: llm
+features:
+  - vision
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    label:
+      en_US: "Max Tokens"
+      zh_Hans: "最大Token数"
+    type: int
+    default: 512
+    min: 1
+    required: true
+    help:
+      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
+      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
+
+  - name: temperature
+    use_template: temperature
+    label:
+      en_US: "Temperature"
+      zh_Hans: "采样温度"
+    type: float
+    default: 0.7
+    min: 0.0
+    max: 1.0
+    precision: 1
+    required: true
+    help:
+      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
+      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
+
+  - name: top_p
+    use_template: top_p
+    label:
+      en_US: "Top P"
+      zh_Hans: "Top P"
+    type: float
+    default: 0.7
+    min: 0.0
+    max: 1.0
+    precision: 1
+    required: true
+    help:
+      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
+      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
+
+  - name: top_k
+    use_template: top_k
+    label:
+      en_US: "Top K"
+      zh_Hans: "Top K"
+    type: int
+    default: 50
+    min: 0
+    max: 100
+    required: true
+    help:
+      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
+      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
+
+  - name: frequency_penalty
+    use_template: frequency_penalty
+    label:
+      en_US: "Frequency Penalty"
+      zh_Hans: "频率惩罚"
+    type: float
+    default: 0
+    min: -1.0
+    max: 1.0
+    precision: 1
+    required: false
+    help:
+      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
+      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
+
+  - name: user
+    use_template: text
+    label:
+      en_US: "User"
+      zh_Hans: "用户"
+    type: string
+    required: false
+    help:
+      en_US: "Used to track and differentiate conversation requests from different users."
+      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,93 @@
+model: InternVL2.5-26B
+label:
+  en_US: InternVL2.5-26B
+model_type: llm
+features:
+  - vision
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    label:
+      en_US: "Max Tokens"
+      zh_Hans: "最大Token数"
+    type: int
+    default: 512
+    min: 1
+    required: true
+    help:
+      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
+      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
+
+  - name: temperature
+    use_template: temperature
+    label:
+      en_US: "Temperature"
+      zh_Hans: "采样温度"
+    type: float
+    default: 0.7
+    min: 0.0
+    max: 1.0
+    precision: 1
+    required: true
+    help:
+      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
+      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
+
+  - name: top_p
+    use_template: top_p
+    label:
+      en_US: "Top P"
+      zh_Hans: "Top P"
+    type: float
+    default: 0.7
+    min: 0.0
+    max: 1.0
+    precision: 1
+    required: true
+    help:
+      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
+      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
+
+  - name: top_k
+    use_template: top_k
+    label:
+      en_US: "Top K"
+      zh_Hans: "Top K"
+    type: int
+    default: 50
+    min: 0
+    max: 100
+    required: true
+    help:
+      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
+      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
+
+  - name: frequency_penalty
+    use_template: frequency_penalty
+    label:
+      en_US: "Frequency Penalty"
+      zh_Hans: "频率惩罚"
+    type: float
+    default: 0
+    min: -1.0
+    max: 1.0
+    precision: 1
+    required: false
+    help:
+      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
+      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
+
+  - name: user
+    use_template: text
+    label:
+      en_US: "User"
+      zh_Hans: "用户"
+    type: string
+    required: false
+    help:
+      en_US: "Used to track and differentiate conversation requests from different users."
+      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -6,3 +6,5 @@
 - deepseek-coder-33B-instruct-chat
 - deepseek-coder-33B-instruct-completions
 - codegeex4-all-9b
+- InternVL2.5-26B
+- InternVL2-8B
@@ -29,18 +29,26 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials, model, model_parameters)
-        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return super()._invoke(
+            GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model),
+            credentials,
+            prompt_messages,
+            model_parameters,
+            tools,
+            stop,
+            stream,
+            user,
+        )
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
-        self._add_custom_parameters(credentials, None)
-        super().validate_credentials(model, credentials)
+        self._add_custom_parameters(credentials, model, None)
+        super().validate_credentials(GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model), credentials)
 
-    def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None:
+    def _add_custom_parameters(self, credentials: dict, model: Optional[str], model_parameters: dict) -> None:
         if model is None:
             model = "Qwen2-72B-Instruct"
 
-        model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
-        credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
+        credentials["endpoint_url"] = "https://ai.gitee.com/v1"
         if model.endswith("completions"):
             credentials["mode"] = LLMMode.COMPLETION.value
         else:
@@ -1,3 +1,5 @@
+- gemini-2.0-flash-exp
+- gemini-2.0-flash-thinking-exp-1219
 - gemini-1.5-pro
 - gemini-1.5-pro-latest
 - gemini-1.5-pro-001
@@ -11,6 +13,8 @@
 - gemini-1.5-flash-exp-0827
 - gemini-1.5-flash-8b-exp-0827
 - gemini-1.5-flash-8b-exp-0924
+- gemini-exp-1206
+- gemini-exp-1121
 - gemini-exp-1114
 - gemini-pro
 - gemini-pro-vision
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576
@@ -0,0 +1,39 @@
+model: gemini-2.0-flash-thinking-exp-1219
+label:
+  en_US: Gemini 2.0 Flash Thinking Exp 1219
+model_type: llm
+features:
+  - agent-thought
+  - vision
+  - document
+  - video
+  - audio
+model_properties:
+  mode: chat
+  context_size: 32767
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_output_tokens
+    use_template: max_tokens
+    default: 8192
+    min: 1
+    max: 8192
+  - name: json_schema
+    use_template: json_schema
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767

@@ -7,6 +7,9 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767

@@ -7,6 +7,9 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -7,6 +7,9 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767
@ -1,27 +1,27 @@
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Optional, Union
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as genai
|
||||
import requests
|
||||
from google.api_core import exceptions
|
||||
from google.generativeai.client import _ClientManager
|
||||
from google.generativeai.types import ContentType, GenerateContentResponse
|
||||
from google.generativeai.types import ContentType, File, GenerateContentResponse
|
||||
from google.generativeai.types.content_types import to_part
|
||||
from PIL import Image
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
|
@ -35,21 +35,7 @@ from core.model_runtime.errors.invoke import (
|
|||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
GOOGLE_AVAILABLE_MIMETYPE = [
|
||||
"application/pdf",
|
||||
"application/x-javascript",
|
||||
"text/javascript",
|
||||
"application/x-python",
|
||||
"text/x-python",
|
||||
"text/plain",
|
||||
"text/html",
|
||||
"text/css",
|
||||
"text/md",
|
||||
"text/csv",
|
||||
"text/xml",
|
||||
"text/rtf",
|
||||
]
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
|
@ -158,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||
"""
|
||||
|
||||
try:
|
||||
ping_message = SystemPromptMessage(content="ping")
|
||||
ping_message = UserPromptMessage(content="ping")
|
||||
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
|
||||
|
||||
except Exception as ex:
|
||||
|
|
@ -201,30 +187,24 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
google_model = genai.GenerativeModel(model_name=model)
|
||||
genai.configure(api_key=credentials["google_api_key"])
|
||||
|
||||
history = []
|
||||
system_instruction = None
|
||||
|
||||
# hack for gemini-pro-vision, which currently does not support multi-turn chat
|
||||
if model == "gemini-pro-vision":
|
||||
last_msg = prompt_messages[-1]
|
||||
content = self._format_message_to_glm_content(last_msg)
|
||||
history.append(content)
|
||||
else:
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
else:
|
||||
history.append(content)
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
elif content["role"] == "system":
|
||||
system_instruction = content["parts"][0]
|
||||
else:
|
||||
history.append(content)
|
||||
|
||||
# Create a new ClientManager with tenant's API key
|
||||
new_client_manager = _ClientManager()
|
||||
new_client_manager.configure(api_key=credentials["google_api_key"])
|
||||
new_custom_client = new_client_manager.make_client("generative")
|
||||
|
||||
google_model._client = new_custom_client
|
||||
if not history:
|
||||
raise InvokeError("The user prompt message is required. You only add a system prompt message.")
|
||||
|
||||
google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
|
||||
response = google_model.generate_content(
|
||||
contents=history,
|
||||
generation_config=genai.types.GenerationConfig(**config_kwargs),
|
||||
|
|
@ -317,8 +297,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||
)
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
if hasattr(response, "usage_metadata") and response.usage_metadata:
|
||||
prompt_tokens = response.usage_metadata.prompt_token_count
|
||||
completion_tokens = response.usage_metadata.candidates_token_count
|
||||
else:
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
|
@ -346,7 +330,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
||||
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
|
|
@ -359,6 +343,40 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||
|
||||
return message_text
|
||||
|
||||
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
|
||||
key = f"{message_content.type.value}:{hash(message_content.data)}"
|
||||
if redis_client.exists(key):
|
||||
try:
|
||||
return genai.get_file(redis_client.get(key).decode())
|
||||
except:
|
||||
pass
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
if message_content.base64_data:
|
||||
file_content = base64.b64decode(message_content.base64_data)
|
||||
temp_file.write(file_content)
|
||||
else:
|
||||
try:
|
||||
response = requests.get(message_content.url)
|
||||
response.raise_for_status()
|
||||
temp_file.write(response.content)
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch data from url {message_content.url}, {ex}")
|
||||
temp_file.flush()
|
||||
|
||||
file = genai.upload_file(path=temp_file.name, mime_type=message_content.mime_type)
|
||||
while file.state.name == "PROCESSING":
|
||||
time.sleep(5)
|
||||
file = genai.get_file(file.name)
|
||||
# google will delete your upload files in 2 days.
|
||||
redis_client.setex(key, 47 * 60 * 60, file.name)
|
||||
|
||||
try:
|
||||
os.unlink(temp_file.name)
|
||||
except PermissionError:
|
||||
# windows may raise permission error
|
||||
pass
|
||||
return file
|
||||
|
||||
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
|
||||
"""
|
||||
Format a single message into glm.Content for Google API
|
||||
|
|
@ -374,28 +392,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||
for c in message.content:
|
||||
if c.type == PromptMessageContentType.TEXT:
|
||||
glm_content["parts"].append(to_part(c.data))
|
||||
elif c.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, c)
|
||||
if message_content.data.startswith("data:"):
|
||||
metadata, base64_data = c.data.split(",", 1)
|
||||
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
||||
else:
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
||||
glm_content["parts"].append(blob)
|
||||
elif c.type == PromptMessageContentType.DOCUMENT:
|
||||
message_content = cast(DocumentPromptMessageContent, c)
|
||||
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
|
||||
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
|
||||
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
|
||||
glm_content["parts"].append(blob)
|
||||
else:
|
||||
glm_content["parts"].append(self._upload_file_content_to_google(c))
|
||||
|
||||
return glm_content
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
|
|
@@ -413,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

            )
            return glm_content
        elif isinstance(message, SystemPromptMessage):
            return {"role": "user", "parts": [to_part(message.content)]}
            if isinstance(message.content, list):
                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
                message.content = "".join(c.data for c in text_contents)
            return {"role": "system", "parts": [to_part(message.content)]}
        elif isinstance(message, ToolPromptMessage):
            return {
                "role": "function",
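The new SystemPromptMessage branch flattens list content to plain text and sends it with a true system role instead of masquerading as user. A small sketch of that flattening, using a stand-in dataclass rather than Dify's TextPromptMessageContent:

    from dataclasses import dataclass


    @dataclass
    class TextPart:  # stand-in for TextPromptMessageContent
        data: str


    def flatten_system_content(content) -> str:
        """Join only the text parts when a system prompt arrives as a list."""
        if isinstance(content, list):
            return "".join(p.data for p in content if isinstance(p, TextPart))
        return content


    assert flatten_system_content([TextPart("You are "), TextPart("helpful.")]) == "You are helpful."
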
@@ -3,8 +3,8 @@ label:
  zh_Hans: 腾讯混元
  en_US: Hunyuan
description:
  en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro and hunyuan-lite.
  zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、hunyuan-standard-256k、hunyuan-pro 和 hunyuan-lite。
  en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro, hunyuan-role, hunyuan-large, hunyuan-large-role, hunyuan-turbo-latest, hunyuan-large-longcontext, hunyuan-turbo, hunyuan-vision, hunyuan-turbo-vision, hunyuan-functioncall and hunyuan-lite.
  zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、hunyuan-standard-256k、hunyuan-pro、hunyuan-role、hunyuan-large、hunyuan-large-role、hunyuan-turbo-latest、hunyuan-large-longcontext、hunyuan-turbo、hunyuan-vision、hunyuan-turbo-vision、hunyuan-functioncall 和 hunyuan-lite。
icon_small:
  en_US: icon_s_en.png
icon_large:
@@ -4,3 +4,10 @@
- hunyuan-pro
- hunyuan-turbo
- hunyuan-vision
- hunyuan-role
- hunyuan-large
- hunyuan-large-role
- hunyuan-large-longcontext
- hunyuan-turbo-latest
- hunyuan-turbo-vision
- hunyuan-functioncall
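The entries above extend the provider's _position.yaml, which, as the name suggests, appears to control the display order of models; a quick sketch of sorting by such a list (PyYAML assumed, inline YAML for self-containment):

    import yaml  # PyYAML, assumed installed

    position = yaml.safe_load("- hunyuan-pro\n- hunyuan-turbo\n- hunyuan-vision\n")
    models = ["hunyuan-vision", "hunyuan-pro"]
    models.sort(key=position.index)  # names missing from the list would need a fallback key
    assert models == ["hunyuan-pro", "hunyuan-vision"]
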
@@ -0,0 +1,38 @@
model: hunyuan-functioncall
label:
  zh_Hans: hunyuan-functioncall
  en_US: hunyuan-functioncall
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
  - name: enable_enhance
    label:
      zh_Hans: 功能增强
      en_US: Enable Enhancement
    type: boolean
    help:
      zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
      en_US: Allow the model to perform an external search to enhance the generation results.
    required: false
    default: true
pricing:
  input: '0.004'
  output: '0.008'
  unit: '0.001'
  currency: RMB
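hunyuan-functioncall above and the six model files that follow all share this schema; a minimal sketch of how a client might clamp a requested max_tokens against the declared rule (PyYAML assumed; spec_text stands for the file contents):

    import yaml  # PyYAML, assumed installed


    def clamp_max_tokens(spec_text: str, requested: int) -> int:
        """Clamp a requested max_tokens into the [min, max] range the model file declares."""
        spec = yaml.safe_load(spec_text)
        rule = next(r for r in spec["parameter_rules"] if r["name"] == "max_tokens")
        return max(rule["min"], min(requested, rule["max"]))

For the file above (max: 32000), a request for 64000 would clamp to 32000.
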
@@ -0,0 +1,38 @@
model: hunyuan-large-longcontext
label:
  zh_Hans: hunyuan-large-longcontext
  en_US: hunyuan-large-longcontext
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 134000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 134000
  - name: enable_enhance
    label:
      zh_Hans: 功能增强
      en_US: Enable Enhancement
    type: boolean
    help:
      zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
      en_US: Allow the model to perform an external search to enhance the generation results.
    required: false
    default: true
pricing:
  input: '0.006'
  output: '0.018'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,38 @@
model: hunyuan-large-role
label:
  zh_Hans: hunyuan-large-role
  en_US: hunyuan-large-role
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
  - name: enable_enhance
    label:
      zh_Hans: 功能增强
      en_US: Enable Enhancement
    type: boolean
    help:
      zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
      en_US: Allow the model to perform an external search to enhance the generation results.
    required: false
    default: true
pricing:
  input: '0.004'
  output: '0.008'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,38 @@
model: hunyuan-large
label:
  zh_Hans: hunyuan-large
  en_US: hunyuan-large
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
  - name: enable_enhance
    label:
      zh_Hans: 功能增强
      en_US: Enable Enhancement
    type: boolean
    help:
      zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
      en_US: Allow the model to perform an external search to enhance the generation results.
    required: false
    default: true
pricing:
  input: '0.004'
  output: '0.012'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,38 @@
model: hunyuan-role
label:
  zh_Hans: hunyuan-role
  en_US: hunyuan-role
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
  - name: enable_enhance
    label:
      zh_Hans: 功能增强
      en_US: Enable Enhancement
    type: boolean
    help:
      zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
      en_US: Allow the model to perform an external search to enhance the generation results.
    required: false
    default: true
pricing:
  input: '0.004'
  output: '0.008'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,38 @@
model: hunyuan-turbo-latest
label:
  zh_Hans: hunyuan-turbo-latest
  en_US: hunyuan-turbo-latest
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
  - name: enable_enhance
    label:
      zh_Hans: 功能增强
      en_US: Enable Enhancement
    type: boolean
    help:
      zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
      en_US: Allow the model to perform an external search to enhance the generation results.
    required: false
    default: true
pricing:
  input: '0.015'
  output: '0.05'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,39 @@
model: hunyuan-turbo-vision
label:
  zh_Hans: hunyuan-turbo-vision
  en_US: hunyuan-turbo-vision
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 8000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 8000
  - name: enable_enhance
    label:
      zh_Hans: 功能增强
      en_US: Enable Enhancement
    type: boolean
    help:
      zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
      en_US: Allow the model to perform an external search to enhance the generation results.
    required: false
    default: true
pricing:
  input: '0.08'
  output: '0.08'
  unit: '0.001'
  currency: RMB
@@ -1,4 +1,7 @@
- gpt-4o-audio-preview
- o1
- o1-2024-12-17
- o1-mini
- o1-mini-2024-09-12
- gpt-4
- gpt-4o
- gpt-4o-2024-05-13

@@ -7,10 +10,6 @@
- chatgpt-4o-latest
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- o1-preview
- o1-preview-2024-09-12
- o1-mini
- o1-mini-2024-09-12
- gpt-4-turbo
- gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview

@@ -25,4 +24,7 @@
- gpt-3.5-turbo-1106
- gpt-3.5-turbo-0613
- gpt-3.5-turbo-instruct
- gpt-4o-audio-preview
- o1-preview
- o1-preview-2024-09-12
- text-davinci-003
@@ -22,7 +22,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 16384
    min: 1
    max: 16384
  - name: response_format
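This same max_tokens bump from 512 to 16384 repeats in the hunks below for the other gpt-4o-family files, with the longer variants also lifting max from 4096. A tiny sketch of how such a rule resolves when the caller omits the parameter (rule fields taken from the hunk above):

    def resolve_param(rule: dict, value: int | None = None) -> int:
        """Fall back to the rule's default, then clamp into [min, max]."""
        v = rule["default"] if value is None else value
        return max(rule["min"], min(v, rule["max"]))


    rule = {"name": "max_tokens", "default": 16384, "min": 1, "max": 16384}
    assert resolve_param(rule) == 16384           # omitted -> the new, larger default
    assert resolve_param(rule, 100_000) == 16384  # over-asking clamps to max
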
@@ -22,9 +22,9 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 16384
    min: 1
    max: 4096
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式

@@ -22,7 +22,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 16384
    min: 1
    max: 16384
  - name: response_format

@@ -22,7 +22,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 16384
    min: 1
    max: 16384
  - name: response_format

@@ -22,9 +22,9 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 16384
    min: 1
    max: 4096
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式

@@ -22,7 +22,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 16384
    min: 1
    max: 16384
  - name: response_format

@@ -22,7 +22,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 16384
    min: 1
    max: 16384
  - name: response_format
@@ -22,9 +22,9 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 16384
    min: 1
    max: 4096
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式

@@ -38,7 +38,7 @@ parameter_rules:
      - text
      - json_object
pricing:
  input: '5.00'
  output: '15.00'
  input: '2.50'
  output: '10.00'
  unit: '0.000001'
  currency: USD
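With unit '0.000001', the listed prices are dollars per token in millionths, i.e. effectively USD per million tokens; a quick check of the new pricing from the hunk above:

    # cost = tokens * unit * price
    input_price, output_price, unit = 2.50, 10.00, 0.000001
    cost = 1_000_000 * unit * input_price + 1_000_000 * unit * output_price
    assert abs(cost - 12.50) < 1e-9  # $2.50 in + $10.00 out per 1M tokens each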