mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in:
commit
a52bf6211a
|
|
@ -47,15 +47,17 @@ jobs:
|
|||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
|
||||
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
|
|
|
|||
|
|
@ -214,3 +214,4 @@ mise.toml
|
|||
|
||||
# AI Assistant
|
||||
.roo/
|
||||
api/.env.backup
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
|
||||
|
||||
from libs.file_utils import search_file_upwards
|
||||
|
||||
from .deploy import DeploymentConfig
|
||||
from .enterprise import EnterpriseFeatureConfig
|
||||
|
|
@ -99,4 +102,12 @@ class DifyConfig(
|
|||
RemoteSettingsSourceFactory(settings_cls),
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
TomlConfigSettingsSource(
|
||||
settings_cls=settings_cls,
|
||||
toml_file=search_file_upwards(
|
||||
base_dir_path=Path(__file__).parent,
|
||||
target_file_name="pyproject.toml",
|
||||
max_search_parent_depth=2,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -223,6 +223,10 @@ class CeleryConfig(DatabaseConfig):
|
|||
default=None,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_PASSWORD: Optional[str] = Field(
|
||||
description="Password of the Redis Sentinel master.",
|
||||
default=None,
|
||||
)
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
|
||||
description="Timeout for Redis Sentinel socket operations in seconds.",
|
||||
default=0.1,
|
||||
|
|
|
|||
|
|
@ -1,17 +1,13 @@
|
|||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
|
||||
|
||||
|
||||
class PackagingInfo(BaseSettings):
|
||||
class PackagingInfo(PyProjectTomlConfig):
|
||||
"""
|
||||
Packaging build information
|
||||
"""
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="1.5.0",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
description="SHA-1 checksum of the git commit used to build the app",
|
||||
default="",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PyProjectConfig(BaseModel):
|
||||
version: str = Field(description="Dify version", default="")
|
||||
|
||||
|
||||
class PyProjectTomlConfig(BaseSettings):
|
||||
"""
|
||||
configs in api/pyproject.toml
|
||||
"""
|
||||
|
||||
project: PyProjectConfig = Field(
|
||||
description="configs in the project section of pyproject.toml",
|
||||
default=PyProjectConfig(),
|
||||
)
|
||||
|
|
@ -41,7 +41,7 @@ class OAuthDataSource(Resource):
|
|||
if not internal_secret:
|
||||
return ({"error": "Internal secret is not set"},)
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return {"data": ""}
|
||||
return {"data": "internal"}
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return {"data": auth_url}, 200
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class VersionApi(Resource):
|
|||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
result = {
|
||||
"version": dify_config.CURRENT_VERSION,
|
||||
"version": dify_config.project.version,
|
||||
"release_date": "",
|
||||
"release_notes": "",
|
||||
"can_auto_update": False,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from libs.login import login_required
|
||||
from models.account import TenantPluginPermission
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
|
@ -497,6 +498,42 @@ class PluginFetchPermissionApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
# check if the user is admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
parser.add_argument("provider", type=str, required=True, location="args")
|
||||
parser.add_argument("action", type=str, required=True, location="args")
|
||||
parser.add_argument("parameter", type=str, required=True, location="args")
|
||||
parser.add_argument("provider_type", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id,
|
||||
user_id,
|
||||
args["plugin_id"],
|
||||
args["provider"],
|
||||
args["action"],
|
||||
args["parameter"],
|
||||
args["provider_type"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
|
||||
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
|
||||
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
|
||||
|
|
@ -521,3 +558,5 @@ api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marke
|
|||
|
||||
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
|
||||
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
|
||||
|
||||
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from core.plugin.entities.request import (
|
|||
RequestInvokeApp,
|
||||
RequestInvokeEncrypt,
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeLLMWithStructuredOutput,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeParameterExtractorNode,
|
||||
RequestInvokeQuestionClassifierNode,
|
||||
|
|
@ -47,6 +48,21 @@ class PluginInvokeLLMApi(Resource):
|
|||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
|
||||
user_model.id, tenant_model, payload
|
||||
)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
|
|
@ -291,6 +307,7 @@ class PluginFetchAppInfoApi(Resource):
|
|||
|
||||
|
||||
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
||||
api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output")
|
||||
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
||||
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
||||
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
|
||||
|
|
|
|||
|
|
@ -29,7 +29,19 @@ class EnterpriseWorkspace(Resource):
|
|||
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
return {"message": "enterprise workspace created."}
|
||||
resp = {
|
||||
"id": tenant.id,
|
||||
"name": tenant.name,
|
||||
"plan": tenant.plan,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
|
||||
"updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"message": "enterprise workspace created.",
|
||||
"tenant": resp,
|
||||
}
|
||||
|
||||
|
||||
class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
|
|
|
|||
|
|
@ -133,6 +133,22 @@ class DatasetListApi(DatasetApiResource):
|
|||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -265,10 +281,20 @@ class DatasetApi(DatasetApiResource):
|
|||
data = request.get_json()
|
||||
|
||||
# check embedding model setting
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
)
|
||||
if (
|
||||
data.get("retrieval_model")
|
||||
and data.get("retrieval_model").get("reranking_model")
|
||||
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
from sqlalchemy import desc, select
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
|
|
@ -18,6 +18,7 @@ from controllers.service_api.app.error import (
|
|||
from controllers.service_api.dataset.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentIndexingError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
|
|
@ -29,7 +30,7 @@ from extensions.ext_database import db
|
|||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.dataset_service import DocumentService
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
from services.file_service import FileService
|
||||
|
||||
|
|
@ -59,6 +60,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -74,6 +76,21 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
|
|
@ -124,6 +141,17 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
|
|
@ -188,6 +216,21 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
raise ValueError("indexing_technique is required.")
|
||||
args["indexing_technique"] = indexing_technique
|
||||
|
||||
if "embedding_model_provider" in args:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args["embedding_model_provider"], args["embedding_model"]
|
||||
)
|
||||
if (
|
||||
"retrieval_model" in args
|
||||
and args["retrieval_model"].get("reranking_model")
|
||||
and args["retrieval_model"].get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args["retrieval_model"].get("reranking_model").get("reranking_provider_name"),
|
||||
args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
# check file
|
||||
|
|
@ -424,6 +467,101 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
|||
return data
|
||||
|
||||
|
||||
class DocumentDetailApi(DatasetApiResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = self.get_dataset(dataset_id, tenant_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
if document.tenant_id != str(tenant_id):
|
||||
raise Forbidden("No permission.")
|
||||
|
||||
metadata = request.args.get("metadata", "all")
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
|
||||
|
||||
if metadata == "only":
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
"updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
|
||||
"indexing_latency": document.indexing_latency,
|
||||
"error": document.error,
|
||||
"enabled": document.enabled,
|
||||
"disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
|
||||
"disabled_by": document.disabled_by,
|
||||
"archived": document.archived,
|
||||
"segment_count": document.segment_count,
|
||||
"average_segment_length": document.average_segment_length,
|
||||
"hit_count": document.hit_count,
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
"updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
|
||||
"indexing_latency": document.indexing_latency,
|
||||
"error": document.error,
|
||||
"enabled": document.enabled,
|
||||
"disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
|
||||
"disabled_by": document.disabled_by,
|
||||
"archived": document.archived,
|
||||
"doc_type": document.doc_type,
|
||||
"doc_metadata": document.doc_metadata_details,
|
||||
"segment_count": document.segment_count,
|
||||
"average_segment_length": document.average_segment_length,
|
||||
"hit_count": document.hit_count,
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DocumentAddByTextApi,
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_text",
|
||||
|
|
@ -447,3 +585,4 @@ api.add_resource(
|
|||
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class IndexApi(Resource):
|
|||
return {
|
||||
"welcome": "Dify OpenAPI",
|
||||
"api_version": "v1",
|
||||
"server_version": dify_config.CURRENT_VERSION,
|
||||
"server_version": dify_config.project.version,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ from flask_restful import Resource
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -317,3 +317,11 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
|
|||
|
||||
class DatasetApiResource(Resource):
|
||||
method_decorators = [validate_dataset_token]
|
||||
|
||||
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
return dataset
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
|
|
@ -36,7 +39,10 @@ from libs.flask_utils import preserve_flask_contexts
|
|||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -450,6 +456,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
|
@ -521,6 +528,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
|
@ -547,6 +555,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, W
|
|||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
|
|
@ -94,6 +95,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
dialogue_count: int,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
) -> None:
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
|
@ -153,6 +155,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
self._conversation_name_generate_thread: Thread | None = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
|
@ -371,6 +374,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
session.commit()
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
|
@ -390,6 +394,8 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
if isinstance(event, QueueNodeExceptionEvent):
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
|
@ -759,3 +765,15 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
return message
|
||||
|
||||
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
|
||||
with Session(db.engine) as session, session.begin():
|
||||
saver = self._draft_var_saver_factory(
|
||||
session=session,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
|
||||
)
|
||||
saver.save(event.process_data, event.outputs)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,20 @@
|
|||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, final
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaver,
|
||||
DraftVariableSaverFactory,
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from factories import file_factory
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
|
|
@ -159,3 +169,38 @@ class BaseAppGenerator:
|
|||
yield f"event: {message}\n\n"
|
||||
|
||||
return gen()
|
||||
|
||||
@final
|
||||
@staticmethod
|
||||
def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
|
||||
def draft_var_saver_factory(
|
||||
session: Session,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
) -> DraftVariableSaver:
|
||||
return DraftVariableSaverImpl(
|
||||
session=session,
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
)
|
||||
else:
|
||||
|
||||
def draft_var_saver_factory(
|
||||
session: Session,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
) -> DraftVariableSaver:
|
||||
return NoopDraftVariableSaver()
|
||||
|
||||
return draft_var_saver_factory
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from core.app.entities.task_entities import (
|
|||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
|
|
@ -506,7 +507,8 @@ class WorkflowResponseConverter:
|
|||
# Convert to tuple to match Sequence type
|
||||
return tuple(flattened_files)
|
||||
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
|
||||
@classmethod
|
||||
def _fetch_files_from_variable_value(cls, value: Union[dict, list, Segment]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
|
|
@ -515,20 +517,30 @@ class WorkflowResponseConverter:
|
|||
if not value:
|
||||
return []
|
||||
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
files: list[Mapping[str, Any]] = []
|
||||
if isinstance(value, FileSegment):
|
||||
files.append(value.value.to_dict())
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files.extend([i.to_dict() for i in value.value])
|
||||
elif isinstance(value, File):
|
||||
files.append(value.to_dict())
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
file = self._get_file_var_from_value(item)
|
||||
file = cls._get_file_var_from_value(item)
|
||||
if file:
|
||||
files.append(file)
|
||||
elif isinstance(value, dict):
|
||||
file = self._get_file_var_from_value(value)
|
||||
elif isinstance(
|
||||
value,
|
||||
dict,
|
||||
):
|
||||
file = cls._get_file_var_from_value(value)
|
||||
if file:
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
|
||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
|||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
|
|
@ -219,6 +220,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.close()
|
||||
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
|
|
@ -233,6 +237,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
|
||||
worker_thread.start()
|
||||
|
||||
draft_var_saver_factory = self._get_draft_var_saver_factory(
|
||||
invoke_from,
|
||||
)
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
|
@ -241,6 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
|
|
@ -471,6 +480,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
|
@ -491,6 +501,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
|||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
|
|
@ -87,6 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
stream: bool,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
) -> None:
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
|
@ -131,6 +133,8 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._workflow_run_id = ""
|
||||
self._invoke_from = queue_manager._invoke_from
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
|
@ -322,6 +326,8 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(
|
||||
|
|
@ -339,6 +345,8 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
if isinstance(event, QueueNodeExceptionEvent):
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
|
|
@ -593,3 +601,15 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
)
|
||||
|
||||
return response
|
||||
|
||||
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
|
||||
with Session(db.engine) as session, session.begin():
|
||||
saver = self._draft_var_saver_factory(
|
||||
session=session,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
|
||||
)
|
||||
saver.save(event.process_data, event.outputs)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
|
|
@ -35,7 +33,6 @@ from core.workflow.entities.variable_pool import VariablePool
|
|||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
BaseNodeEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
|
|
@ -70,9 +67,6 @@ from core.workflow.workflow_entry import WorkflowEntry
|
|||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner(AppRunner):
|
||||
|
|
@ -400,7 +394,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
self._save_draft_var_for_event(event)
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
|
|
@ -464,7 +457,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
self._save_draft_var_for_event(event)
|
||||
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
|
|
@ -718,30 +710,3 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
|
||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
||||
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def _save_draft_var_for_event(self, event: BaseNodeEvent):
|
||||
run_result = event.route_node_state.node_run_result
|
||||
if run_result is None:
|
||||
return
|
||||
process_data = run_result.process_data
|
||||
outputs = run_result.outputs
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=self._get_app_id(),
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
# FIXME(QuantumGhost): rely on private state of queue_manager is not ideal.
|
||||
invoke_from=self.queue_manager._invoke_from,
|
||||
node_execution_id=event.id,
|
||||
enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
|
||||
)
|
||||
draft_var_saver.save(process_data=process_data, outputs=outputs)
|
||||
|
||||
|
||||
def _remove_first_element_from_variable_string(key: str) -> str:
|
||||
"""
|
||||
Remove the first element from the prefix.
|
||||
"""
|
||||
prefix, remaining = key.split(".", maxsplit=1)
|
||||
return remaining
|
||||
|
|
|
|||
|
|
@ -395,6 +395,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.llm_result.usage.latency = message.provider_response_latency
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,11 @@ class CommonParameterType(StrEnum):
|
|||
MODEL_SELECTOR = "model-selector"
|
||||
TOOLS_SELECTOR = "array[tools]"
|
||||
|
||||
# Dynamic select parameter
|
||||
# Once you are not sure about the available options until authorization is done
|
||||
# eg: Select a Slack channel from a Slack workspace
|
||||
DYNAMIC_SELECT = "dynamic-select"
|
||||
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,374 @@
|
|||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, Optional, cast, overload
|
||||
|
||||
import json_repair
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[True] = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[False] = False,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
"""
|
||||
Invoke large language model with structured output
|
||||
1. This method invokes model_instance.invoke_llm with json_schema
|
||||
2. Try to parse the result as structured output
|
||||
|
||||
:param prompt_messages: prompt messages
|
||||
:param json_schema: json schema
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# handle native json schema
|
||||
model_parameters_with_json_schema: dict[str, Any] = {
|
||||
**(model_parameters or {}),
|
||||
}
|
||||
|
||||
if model_schema.support_structure_output:
|
||||
model_parameters = _handle_native_json_schema(
|
||||
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
|
||||
)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
|
||||
|
||||
# handle prompt based schema
|
||||
prompt_messages = _handle_prompt_based_schema(
|
||||
prompt_messages=prompt_messages,
|
||||
structured_output_schema=json_schema,
|
||||
)
|
||||
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters_with_json_schema,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if isinstance(llm_result, LLMResult):
|
||||
if not isinstance(llm_result.message.content, str):
|
||||
raise OutputParserError(
|
||||
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(llm_result.message.content),
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
system_fingerprint=llm_result.system_fingerprint,
|
||||
prompt_messages=llm_result.prompt_messages,
|
||||
)
|
||||
else:
|
||||
|
||||
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
result_text: str = ""
|
||||
prompt_messages: Sequence[PromptMessage] = []
|
||||
system_fingerprint: Optional[str] = None
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(result_text),
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
finish_reason=None,
|
||||
),
|
||||
)
|
||||
|
||||
return generator()
|
||||
|
||||
|
||||
def _handle_native_json_schema(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
structured_output_schema: Mapping,
|
||||
model_parameters: dict,
|
||||
rules: list[ParameterRule],
|
||||
) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
:return: Updated model parameters with JSON schema configuration
|
||||
"""
|
||||
# Process schema according to model requirements
|
||||
schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema)
|
||||
|
||||
# Set JSON schema in parameters
|
||||
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
|
||||
return model_parameters
|
||||
|
||||
|
||||
def _set_response_format(model_parameters: dict, rules: list) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
|
||||
|
||||
def _handle_prompt_based_schema(
|
||||
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
This function modifies the prompt messages to include schema-based output requirements.
|
||||
|
||||
Args:
|
||||
prompt_messages: Original sequence of prompt messages
|
||||
|
||||
Returns:
|
||||
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||
"""
|
||||
# Convert schema to string format
|
||||
schema_str = json.dumps(structured_output_schema, ensure_ascii=False)
|
||||
|
||||
# Find existing system prompt with schema placeholder
|
||||
system_prompt = next(
|
||||
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||
None,
|
||||
)
|
||||
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||
# Prepare system prompt content
|
||||
system_prompt_content = (
|
||||
structured_output_prompt + "\n\n" + system_prompt.content
|
||||
if system_prompt and isinstance(system_prompt.content, str)
|
||||
else structured_output_prompt
|
||||
)
|
||||
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||
|
||||
# Extract content from the last user message
|
||||
|
||||
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||
updated_prompt = [system_prompt] + filtered_prompts
|
||||
|
||||
return updated_prompt
|
||||
|
||||
|
||||
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||
structured_output: Mapping[str, Any] = {}
|
||||
parsed: Mapping[str, Any] = {}
|
||||
try:
|
||||
parsed = TypeAdapter(Mapping).validate_json(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
except ValidationError:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
temp_parsed = json_repair.loads(result_text)
|
||||
if not isinstance(temp_parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(temp_parsed, list):
|
||||
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = cast(dict, temp_parsed)
|
||||
return structured_output
|
||||
|
||||
|
||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
|
||||
"""
|
||||
Prepare JSON schema based on model requirements.
|
||||
|
||||
Different models have different requirements for JSON schema formatting.
|
||||
This function handles these differences.
|
||||
|
||||
:param schema: The original JSON schema
|
||||
:return: Processed schema compatible with the current model
|
||||
"""
|
||||
|
||||
# Deep copy to avoid modifying the original schema
|
||||
processed_schema = dict(deepcopy(schema))
|
||||
|
||||
# Convert boolean types to string types (common requirement)
|
||||
convert_boolean_to_string(processed_schema)
|
||||
|
||||
# Apply model-specific transformations
|
||||
if SpecialModelType.GEMINI in model_schema.model:
|
||||
remove_additional_properties(processed_schema)
|
||||
return processed_schema
|
||||
elif SpecialModelType.OLLAMA in provider:
|
||||
return processed_schema
|
||||
else:
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict) -> None:
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Remove additionalProperties at current level
|
||||
schema.pop("additionalProperties", None)
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict) -> None:
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Check for boolean type at current level
|
||||
if schema.get("type") == "boolean":
|
||||
schema["type"] = "string"
|
||||
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
|
|
@ -291,3 +291,21 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
|
|||
|
||||
Now, generate a JSON Schema based on my description
|
||||
""" # noqa: E501
|
||||
|
||||
STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
|
||||
constraints:
|
||||
- You must output in JSON format.
|
||||
- Do not output boolean value, use string type instead.
|
||||
- Do not output integer or float value, use number type instead.
|
||||
eg:
|
||||
Here is the JSON schema:
|
||||
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
|
||||
|
||||
Here is the user's question:
|
||||
My name is John Doe and I am 30 years old.
|
||||
|
||||
output:
|
||||
{"name": "John Doe", "age": 30}
|
||||
Here is the JSON schema:
|
||||
{{schema}}
|
||||
""" # noqa: E501
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -101,6 +101,20 @@ class LLMResult(BaseModel):
|
|||
system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
class LLMStructuredOutput(BaseModel):
|
||||
"""
|
||||
Model class for llm structured output.
|
||||
"""
|
||||
|
||||
structured_output: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
|
||||
"""
|
||||
Model class for llm result with structured output.
|
||||
"""
|
||||
|
||||
|
||||
class LLMResultChunkDelta(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk delta.
|
||||
|
|
@ -123,6 +137,12 @@ class LLMResultChunk(BaseModel):
|
|||
delta: LLMResultChunkDelta
|
||||
|
||||
|
||||
class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
|
||||
"""
|
||||
Model class for llm result chunk with structured output.
|
||||
"""
|
||||
|
||||
|
||||
class NumTokensResult(PriceInfo):
|
||||
"""
|
||||
Model class for number of tokens result.
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
metadata=metadata,
|
||||
session_id=trace_info.conversation_id,
|
||||
tags=["message", "workflow"],
|
||||
version=trace_info.workflow_run_version,
|
||||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
workflow_span_data = LangfuseSpan(
|
||||
|
|
@ -108,6 +109,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
metadata=metadata,
|
||||
session_id=trace_info.conversation_id,
|
||||
tags=["workflow"],
|
||||
version=trace_info.workflow_run_version,
|
||||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
|
||||
|
|
@ -172,37 +174,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
}
|
||||
)
|
||||
|
||||
# add span
|
||||
if trace_info.message_id:
|
||||
span_data = LangfuseSpan(
|
||||
id=node_execution_id,
|
||||
name=node_type,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
trace_id=trace_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id,
|
||||
)
|
||||
else:
|
||||
span_data = LangfuseSpan(
|
||||
id=node_execution_id,
|
||||
name=node_type,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
trace_id=trace_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
# add generation span
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
total_token = metadata.get("total_tokens", 0)
|
||||
prompt_tokens = 0
|
||||
|
|
@ -226,10 +198,10 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
)
|
||||
|
||||
node_generation_data = LangfuseGeneration(
|
||||
name="llm",
|
||||
id=node_execution_id,
|
||||
name=node_name,
|
||||
trace_id=trace_id,
|
||||
model=process_data.get("model_name"),
|
||||
parent_observation_id=node_execution_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
input=inputs,
|
||||
|
|
@ -237,11 +209,30 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
self.add_generation(langfuse_generation_data=node_generation_data)
|
||||
|
||||
# add normal span
|
||||
else:
|
||||
span_data = LangfuseSpan(
|
||||
id=node_execution_id,
|
||||
name=node_name,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
trace_id=trace_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
|
||||
# get message file data
|
||||
file_list = trace_info.file_list
|
||||
|
|
@ -284,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
|
||||
# start add span
|
||||
# add generation
|
||||
generation_usage = GenerationUsage(
|
||||
input=trace_info.message_tokens,
|
||||
output=trace_info.answer_tokens,
|
||||
|
|
|
|||
|
|
@ -2,8 +2,15 @@ import tempfile
|
|||
from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
|
|
@ -12,6 +19,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeLLMWithStructuredOutput,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeRerank,
|
||||
RequestInvokeSpeech2Text,
|
||||
|
|
@ -81,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_llm_with_structured_output(
|
||||
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
|
||||
):
|
||||
"""
|
||||
invoke llm with structured output
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {payload.model}")
|
||||
|
||||
response = invoke_llm_with_structured_output(
|
||||
provider=payload.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=payload.prompt_messages,
|
||||
json_schema=payload.structured_output_schema,
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
model_parameters=payload.completion_params,
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
|
||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=response.model,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
structured_output=response.structured_output,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ from core.tools.entities.common_entities import I18nObject
|
|||
class PluginParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
icon: Optional[str] = Field(
|
||||
default=None, description="The icon of the option, can be a url or a base64 encoded image"
|
||||
)
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -35,6 +38,7 @@ class PluginParameterType(enum.StrEnum):
|
|||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
|
@ -9,6 +9,7 @@ from core.agent.plugin_entities import AgentProviderEntityWithPlugin
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
|
|
@ -186,3 +187,7 @@ class PluginOAuthCredentialsResponse(BaseModel):
|
|||
class PluginListResponse(BaseModel):
|
||||
list: list[PluginEntity]
|
||||
total: int
|
||||
|
||||
|
||||
class PluginDynamicSelectOptionsResponse(BaseModel):
|
||||
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
||||
|
|
|
|||
|
|
@ -82,6 +82,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
|||
return v
|
||||
|
||||
|
||||
class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
|
||||
"""
|
||||
Request to invoke LLM with structured output
|
||||
"""
|
||||
|
||||
structured_output_schema: dict[str, Any] = Field(
|
||||
default_factory=dict, description="The schema of the structured output in JSON schema format"
|
||||
)
|
||||
|
||||
|
||||
class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke text embedding
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class DynamicSelectClient(BasePluginClient):
|
||||
def fetch_dynamic_select_options(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
credentials: Mapping[str, Any],
|
||||
parameter: str,
|
||||
) -> PluginDynamicSelectOptionsResponse:
|
||||
"""
|
||||
Fetch dynamic select options for a plugin parameter.
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/dynamic_select/fetch_parameter_options",
|
||||
PluginDynamicSelectOptionsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": GenericProviderID(provider).provider_name,
|
||||
"credentials": credentials,
|
||||
"provider_action": action,
|
||||
"parameter": parameter,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for options in response:
|
||||
return options
|
||||
|
||||
raise ValueError(f"Plugin service returned no options for parameter '{parameter}' in provider '{provider}'")
|
||||
|
|
@ -1010,6 +1010,9 @@ class DatasetRetrieval:
|
|||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
|
||||
):
|
||||
if value is None:
|
||||
return
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
match condition:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
|
|
@ -25,6 +26,8 @@ class SimpleCode(BuiltinTool):
|
|||
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
|
||||
raise ValueError(f"Only python3 and javascript are supported, not {language}")
|
||||
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
|
||||
yield self.create_text_message(result)
|
||||
try:
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
yield self.create_text_message(result)
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
|
|
|
|||
|
|
@ -240,6 +240,7 @@ class ToolParameter(PluginParameter):
|
|||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ class ProviderConfigEncrypter(BaseModel):
|
|||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
|
|
|
|||
|
|
@ -66,11 +66,21 @@ class WorkflowNodeExecution(BaseModel):
|
|||
but they are not stored in the model.
|
||||
"""
|
||||
|
||||
# Core identification fields
|
||||
id: str # Unique identifier for this execution record
|
||||
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
|
||||
# --------- Core identification fields ---------
|
||||
|
||||
# Unique identifier for this execution record, used when persisting to storage.
|
||||
# Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
|
||||
id: str
|
||||
|
||||
# Optional secondary ID for cross-referencing purposes.
|
||||
#
|
||||
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
|
||||
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
|
||||
# In most scenarios, `id` should be used as the primary identifier.
|
||||
node_execution_id: Optional[str] = None
|
||||
workflow_id: str # ID of the workflow this node belongs to
|
||||
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
# --------- Core identification fields ends ---------
|
||||
|
||||
# Execution positioning and flow
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
|
|
|
|||
|
|
@ -158,7 +158,10 @@ class AgentNode(ToolNode):
|
|||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
|
|
@ -166,7 +169,8 @@ class AgentNode(ToolNode):
|
|||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
parameter_value = json.loads(parameter_value)
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import logging
|
|||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
|
|
@ -201,44 +200,3 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||
stream_out_answer_node_ids.append(answer_node_id)
|
||||
|
||||
return stream_out_answer_node_ids
|
||||
|
||||
@classmethod
|
||||
def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
file_var = cls._get_file_var_from_value(item)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
elif isinstance(value, dict):
|
||||
file_var = cls._get_file_var_from_value(value)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: dict | list):
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ class Executor:
|
|||
try:
|
||||
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e))
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response # type: ignore
|
||||
|
||||
|
|
|
|||
|
|
@ -490,6 +490,9 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
|
||||
):
|
||||
if value is None:
|
||||
return
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
match condition:
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ import logging
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import json_repair
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
|
|
@ -18,7 +18,13 @@ from core.model_runtime.entities import (
|
|||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
|
|
@ -31,7 +37,6 @@ from core.model_runtime.entities.model_entities import (
|
|||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -62,11 +67,6 @@ from core.workflow.nodes.event import (
|
|||
RunRetrieverResourceEvent,
|
||||
RunStreamChunkEvent,
|
||||
)
|
||||
from core.workflow.utils.structured_output.entities import (
|
||||
ResponseFormat,
|
||||
SpecialModelType,
|
||||
)
|
||||
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from . import llm_utils
|
||||
|
|
@ -143,12 +143,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
|
||||
"""Process structured output if enabled"""
|
||||
if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
|
||||
return None
|
||||
return self._parse_structured_output(text)
|
||||
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
result_text = ""
|
||||
|
|
@ -244,6 +238,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
stop=stop,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, RunStreamChunkEvent):
|
||||
yield event
|
||||
|
|
@ -254,10 +250,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
elif isinstance(event, LLMStructuredOutput):
|
||||
structured_output = event
|
||||
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
structured_output = process_structured_output(result_text)
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
if self._file_outputs is not None:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
|
|
@ -302,20 +300,40 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
node_data_model.name, model_instance.credentials
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {node_data_model.name}")
|
||||
|
||||
if self.node_data.structured_output_enabled:
|
||||
output_schema = self._fetch_structured_output_schema()
|
||||
invoke_result = invoke_llm_with_structured_output(
|
||||
provider=model_instance.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=output_schema,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
)
|
||||
else:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
)
|
||||
|
||||
return self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
def _handle_invoke_result(
|
||||
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
|
||||
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
|
||||
# For blocking mode
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
event = self._handle_blocking_result(invoke_result=invoke_result)
|
||||
|
|
@ -329,23 +347,32 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
full_text_buffer = io.StringIO()
|
||||
for result in invoke_result:
|
||||
contents = result.delta.message.content
|
||||
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
|
||||
full_text_buffer.write(text_part)
|
||||
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
|
||||
# Consume the invoke result and handle generator exception
|
||||
try:
|
||||
for result in invoke_result:
|
||||
if isinstance(result, LLMResultChunkWithStructuredOutput):
|
||||
yield result
|
||||
if isinstance(result, LLMResultChunk):
|
||||
contents = result.delta.message.content
|
||||
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
|
||||
full_text_buffer.write(text_part)
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
model = result.model
|
||||
if len(prompt_messages) == 0:
|
||||
# TODO(QuantumGhost): it seems that this update has no visable effect.
|
||||
# What's the purpose of the line below?
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if usage.prompt_tokens == 0 and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
if finish_reason is None and result.delta.finish_reason:
|
||||
finish_reason = result.delta.finish_reason
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
model = result.model
|
||||
if len(prompt_messages) == 0:
|
||||
# TODO(QuantumGhost): it seems that this update has no visable effect.
|
||||
# What's the purpose of the line below?
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if usage.prompt_tokens == 0 and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
if finish_reason is None and result.delta.finish_reason:
|
||||
finish_reason = result.delta.finish_reason
|
||||
except OutputParserError as e:
|
||||
raise LLMNodeError(f"Failed to parse structured output: {e}")
|
||||
|
||||
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
|
||||
|
||||
|
|
@ -522,12 +549,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
if self.node_data.structured_output_enabled:
|
||||
if model_schema.support_structure_output:
|
||||
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
self._set_response_format(completion_params, model_schema.parameter_rules)
|
||||
model_config_with_cred.parameters = completion_params
|
||||
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
|
||||
node_data_model.completion_params = completion_params
|
||||
|
|
@ -719,32 +740,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_config.model} not exist.")
|
||||
if self.node_data.structured_output_enabled:
|
||||
if not model_schema.support_structure_output:
|
||||
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||
prompt_messages=filtered_prompt_messages,
|
||||
)
|
||||
return filtered_prompt_messages, model_config.stop
|
||||
|
||||
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
|
||||
structured_output: dict[str, Any] = {}
|
||||
try:
|
||||
parsed = json.loads(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
except json.JSONDecodeError as e:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
parsed = json_repair.loads(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(parsed, list):
|
||||
parsed = next((item for item in parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
return structured_output
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
|
@ -934,104 +931,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
self._file_outputs.append(saved_file)
|
||||
return saved_file
|
||||
|
||||
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
:return: Updated model parameters with JSON schema configuration
|
||||
"""
|
||||
# Process schema according to model requirements
|
||||
schema = self._fetch_structured_output_schema()
|
||||
schema_json = self._prepare_schema_for_model(schema)
|
||||
|
||||
# Set JSON schema in parameters
|
||||
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
|
||||
return model_parameters
|
||||
|
||||
def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
This function modifies the prompt messages to include schema-based output requirements.
|
||||
|
||||
Args:
|
||||
prompt_messages: Original sequence of prompt messages
|
||||
|
||||
Returns:
|
||||
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||
"""
|
||||
# Convert schema to string format
|
||||
schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
|
||||
|
||||
# Find existing system prompt with schema placeholder
|
||||
system_prompt = next(
|
||||
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||
None,
|
||||
)
|
||||
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||
# Prepare system prompt content
|
||||
system_prompt_content = (
|
||||
structured_output_prompt + "\n\n" + system_prompt.content
|
||||
if system_prompt and isinstance(system_prompt.content, str)
|
||||
else structured_output_prompt
|
||||
)
|
||||
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||
|
||||
# Extract content from the last user message
|
||||
|
||||
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||
updated_prompt = [system_prompt] + filtered_prompts
|
||||
|
||||
return updated_prompt
|
||||
|
||||
def _set_response_format(self, model_parameters: dict, rules: list) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
|
||||
def _prepare_schema_for_model(self, schema: dict) -> dict:
|
||||
"""
|
||||
Prepare JSON schema based on model requirements.
|
||||
|
||||
Different models have different requirements for JSON schema formatting.
|
||||
This function handles these differences.
|
||||
|
||||
:param schema: The original JSON schema
|
||||
:return: Processed schema compatible with the current model
|
||||
"""
|
||||
|
||||
# Deep copy to avoid modifying the original schema
|
||||
processed_schema = schema.copy()
|
||||
|
||||
# Convert boolean types to string types (common requirement)
|
||||
convert_boolean_to_string(processed_schema)
|
||||
|
||||
# Apply model-specific transformations
|
||||
if SpecialModelType.GEMINI in self.node_data.model.name:
|
||||
remove_additional_properties(processed_schema)
|
||||
return processed_schema
|
||||
elif SpecialModelType.OLLAMA in self.node_data.model.provider:
|
||||
return processed_schema
|
||||
else:
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
|
||||
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
|
||||
"""
|
||||
Fetch model schema
|
||||
|
|
@ -1243,49 +1142,3 @@ def _handle_completion_template(
|
|||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict) -> None:
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Remove additionalProperties at current level
|
||||
schema.pop("additionalProperties", None)
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict) -> None:
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Check for boolean type at current level
|
||||
if schema.get("type") == "boolean":
|
||||
schema["type"] = "string"
|
||||
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
|
|
|
|||
|
|
@ -167,7 +167,9 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class DraftVariableSaver(Protocol):
|
||||
@abc.abstractmethod
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
|
||||
pass
|
||||
|
||||
|
||||
class DraftVariableSaverFactory(Protocol):
|
||||
@abc.abstractmethod
|
||||
def __call__(
|
||||
self,
|
||||
session: Session,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
) -> "DraftVariableSaver":
|
||||
pass
|
||||
|
||||
|
||||
class NoopDraftVariableSaver(DraftVariableSaver):
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
|
||||
pass
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from enum import StrEnum
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
|
||||
constraints:
|
||||
- You must output in JSON format.
|
||||
- Do not output boolean value, use string type instead.
|
||||
- Do not output integer or float value, use number type instead.
|
||||
eg:
|
||||
Here is the JSON schema:
|
||||
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
|
||||
|
||||
Here is the user's question:
|
||||
My name is John Doe and I am 30 years old.
|
||||
|
||||
output:
|
||||
{"name": "John Doe", "age": 30}
|
||||
Here is the JSON schema:
|
||||
{{schema}}
|
||||
""" # noqa: E501
|
||||
|
|
@ -27,6 +27,7 @@ from core.workflow.enums import SystemVariableKey
|
|||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -160,12 +161,13 @@ class WorkflowCycleManager:
|
|||
exceptions_count: int = 0,
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
now = naive_utc_now()
|
||||
|
||||
workflow_execution.status = WorkflowExecutionStatus(status.value)
|
||||
workflow_execution.error_message = error_message
|
||||
workflow_execution.total_tokens = total_tokens
|
||||
workflow_execution.total_steps = total_steps
|
||||
workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_execution.finished_at = now
|
||||
workflow_execution.exceptions_count = exceptions_count
|
||||
|
||||
# Use the instance repository to find running executions for a workflow run
|
||||
|
|
@ -174,7 +176,6 @@ class WorkflowCycleManager:
|
|||
)
|
||||
|
||||
# Update the domain models
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
for node_execution in running_node_executions:
|
||||
if node_execution.node_execution_id:
|
||||
# Update the domain model
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@ def init_app(app: DifyApp):
|
|||
@app.after_request
|
||||
def after_request(response):
|
||||
"""Add Version headers to the response."""
|
||||
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
|
||||
response.headers.add("X-Version", dify_config.project.version)
|
||||
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
|
||||
return response
|
||||
|
||||
@app.route("/health")
|
||||
def health():
|
||||
return Response(
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.project.version}),
|
||||
status=200,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ def init_app(app: DifyApp) -> Celery:
|
|||
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
"sentinel_kwargs": {
|
||||
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
"password": dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def init_app(app: DifyApp):
|
|||
logging.getLogger().addHandler(exception_handler)
|
||||
|
||||
def init_flask_instrumentor(app: DifyApp):
|
||||
meter = get_meter("http_metrics", version=dify_config.CURRENT_VERSION)
|
||||
meter = get_meter("http_metrics", version=dify_config.project.version)
|
||||
_http_response_counter = meter.create_counter(
|
||||
"http.server.response.count",
|
||||
description="Total number of HTTP responses by status code, method and target",
|
||||
|
|
@ -163,7 +163,7 @@ def init_app(app: DifyApp):
|
|||
resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: dify_config.APPLICATION_NAME,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.PROCESS_PID: os.getpid(),
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
|
|
|
|||
|
|
@ -35,6 +35,6 @@ def init_app(app: DifyApp):
|
|||
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
|
||||
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
|
||||
environment=dify_config.DEPLOY_ENV,
|
||||
release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
|
||||
release=f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
before_send=before_send,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
def search_file_upwards(
|
||||
base_dir_path: Path,
|
||||
target_file_name: str,
|
||||
max_search_parent_depth: int,
|
||||
) -> Path:
|
||||
"""
|
||||
Find a target file in the current directory or its parent directories up to a specified depth.
|
||||
:param base_dir_path: Starting directory path to search from.
|
||||
:param target_file_name: Name of the file to search for.
|
||||
:param max_search_parent_depth: Maximum number of parent directories to search upwards.
|
||||
:return: Path of the file if found, otherwise None.
|
||||
"""
|
||||
current_path = base_dir_path.resolve()
|
||||
for _ in range(max_search_parent_depth):
|
||||
candidate_path = current_path / target_file_name
|
||||
if candidate_path.is_file():
|
||||
return candidate_path
|
||||
parent_path = current_path.parent
|
||||
if parent_path == current_path: # reached the root directory
|
||||
break
|
||||
else:
|
||||
current_path = parent_path
|
||||
|
||||
raise ValueError(
|
||||
f"File '{target_file_name}' not found in the directory '{base_dir_path.resolve()}' or its parent directories"
|
||||
f" in depth of {max_search_parent_depth}."
|
||||
)
|
||||
|
|
@ -140,7 +140,7 @@ class Dataset(Base):
|
|||
def word_count(self):
|
||||
return (
|
||||
db.session.query(Document)
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count)))
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
|
||||
.filter(Document.dataset_id == self.id)
|
||||
.scalar()
|
||||
)
|
||||
|
|
@ -448,7 +448,7 @@ class Document(Base):
|
|||
def hit_count(self):
|
||||
return (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
|
||||
.filter(DocumentSegment.document_id == self.id)
|
||||
.scalar()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -676,7 +676,7 @@ class Conversation(Base):
|
|||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
value["tool_file_id"] = value["related_id"]
|
||||
elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
value["upload_file_id"] = value["related_id"]
|
||||
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
|
||||
elif isinstance(value, list) and all(
|
||||
|
|
@ -686,7 +686,7 @@ class Conversation(Base):
|
|||
for item in value:
|
||||
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
item["tool_file_id"] = item["related_id"]
|
||||
elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
item["upload_file_id"] = item["related_id"]
|
||||
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
|
||||
|
||||
|
|
@ -946,7 +946,7 @@ class Message(Base):
|
|||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
value["tool_file_id"] = value["related_id"]
|
||||
elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
value["upload_file_id"] = value["related_id"]
|
||||
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
|
||||
elif isinstance(value, list) and all(
|
||||
|
|
@ -956,7 +956,7 @@ class Message(Base):
|
|||
for item in value:
|
||||
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
item["tool_file_id"] = item["related_id"]
|
||||
elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
item["upload_file_id"] = item["related_id"]
|
||||
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
|
||||
return inputs
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "dify-api"
|
||||
dynamic = ["version"]
|
||||
version = "1.5.1"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
|
@ -198,7 +198,7 @@ vdb = [
|
|||
"pymochow==1.3.1",
|
||||
"pyobvector~=0.1.6",
|
||||
"qdrant-client==1.9.0",
|
||||
"tablestore==6.1.0",
|
||||
"tablestore==6.2.0",
|
||||
"tcvectordb~=1.6.4",
|
||||
"tidb-vector==0.0.9",
|
||||
"upstash-vector==0.6.0",
|
||||
|
|
|
|||
|
|
@ -889,7 +889,7 @@ class RegisterService:
|
|||
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
|
||||
|
||||
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
|
||||
dify_setup = DifySetup(version=dify_config.project.version)
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -278,6 +278,23 @@ class DatasetService:
|
|||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
|
||||
@staticmethod
|
||||
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=reranking_model_provider,
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model,
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
|
||||
@staticmethod
|
||||
def update_dataset(dataset_id, data, user):
|
||||
"""
|
||||
|
|
@ -2207,6 +2224,7 @@ class SegmentService:
|
|||
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
segment.answer = args.answer
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
|
||||
else:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import BuiltinToolProvider
|
||||
|
||||
|
||||
class PluginParameterService:
|
||||
@staticmethod
|
||||
def get_dynamic_select_options(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
parameter: str,
|
||||
provider_type: Literal["tool"],
|
||||
) -> Sequence[PluginParameterOption]:
|
||||
"""
|
||||
Get dynamic select options for a plugin parameter.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID.
|
||||
plugin_id: The plugin ID.
|
||||
provider: The provider name.
|
||||
action: The action name.
|
||||
parameter: The parameter name.
|
||||
"""
|
||||
credentials: Mapping[str, Any] = {}
|
||||
|
||||
match provider_type:
|
||||
case "tool":
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# check if credentials are required
|
||||
if not provider_controller.need_credentials:
|
||||
credentials = {}
|
||||
else:
|
||||
# fetch credentials from db
|
||||
with Session(db.engine) as session:
|
||||
db_record = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_record is None:
|
||||
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
||||
|
||||
credentials = tool_configuration.decrypt(db_record.credentials)
|
||||
case _:
|
||||
raise ValueError(f"Invalid provider type: {provider_type}")
|
||||
|
||||
return (
|
||||
DynamicSelectClient()
|
||||
.fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter)
|
||||
.options
|
||||
)
|
||||
|
|
@ -154,7 +154,7 @@ class WorkflowDraftVariableService:
|
|||
variables = (
|
||||
# Do not load the `value` field.
|
||||
query.options(orm.defer(WorkflowDraftVariable.value))
|
||||
.order_by(WorkflowDraftVariable.id.desc())
|
||||
.order_by(WorkflowDraftVariable.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset((page - 1) * limit)
|
||||
.all()
|
||||
|
|
@ -168,7 +168,7 @@ class WorkflowDraftVariableService:
|
|||
WorkflowDraftVariable.node_id == node_id,
|
||||
)
|
||||
query = self._session.query(WorkflowDraftVariable).filter(*criteria)
|
||||
variables = query.order_by(WorkflowDraftVariable.id.desc()).all()
|
||||
variables = query.order_by(WorkflowDraftVariable.created_at.desc()).all()
|
||||
return WorkflowDraftVariableList(variables=variables)
|
||||
|
||||
def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
|
||||
|
|
@ -235,7 +235,9 @@ class WorkflowDraftVariableService:
|
|||
self._session.flush()
|
||||
return variable
|
||||
|
||||
def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
|
||||
def _reset_node_var_or_sys_var(
|
||||
self, workflow: Workflow, variable: WorkflowDraftVariable
|
||||
) -> WorkflowDraftVariable | None:
|
||||
# If a variable does not allow updating, it makes no sence to resetting it.
|
||||
if not variable.editable:
|
||||
return variable
|
||||
|
|
@ -259,28 +261,35 @@ class WorkflowDraftVariableService:
|
|||
self._session.flush()
|
||||
return None
|
||||
|
||||
# Get node type for proper value extraction
|
||||
node_config = workflow.get_node_config_by_id(variable.node_id)
|
||||
node_type = workflow.get_node_type_from_node_config(node_config)
|
||||
|
||||
outputs_dict = node_exec.outputs_dict or {}
|
||||
# a sentinel value used to check the absent of the output variable key.
|
||||
absent = object()
|
||||
|
||||
# Note: Based on the implementation in `_build_from_variable_assigner_mapping`,
|
||||
# VariableAssignerNode (both v1 and v2) can only create conversation draft variables.
|
||||
# For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes.
|
||||
#
|
||||
# This implementation must remain synchronized with the `_build_from_variable_assigner_mapping`
|
||||
# and `save` methods.
|
||||
if node_type == NodeType.VARIABLE_ASSIGNER:
|
||||
return variable
|
||||
if variable.get_variable_type() == DraftVariableType.NODE:
|
||||
# Get node type for proper value extraction
|
||||
node_config = workflow.get_node_config_by_id(variable.node_id)
|
||||
node_type = workflow.get_node_type_from_node_config(node_config)
|
||||
|
||||
if variable.name not in outputs_dict:
|
||||
# Note: Based on the implementation in `_build_from_variable_assigner_mapping`,
|
||||
# VariableAssignerNode (both v1 and v2) can only create conversation draft variables.
|
||||
# For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes.
|
||||
#
|
||||
# This implementation must remain synchronized with the `_build_from_variable_assigner_mapping`
|
||||
# and `save` methods.
|
||||
if node_type == NodeType.VARIABLE_ASSIGNER:
|
||||
return variable
|
||||
output_value = outputs_dict.get(variable.name, absent)
|
||||
else:
|
||||
output_value = outputs_dict.get(f"sys.{variable.name}", absent)
|
||||
|
||||
# We cannot use `is None` to check the existence of an output variable here as
|
||||
# the value of the output may be `None`.
|
||||
if output_value is absent:
|
||||
# If variable not found in execution data, delete the variable
|
||||
self._session.delete(instance=variable)
|
||||
self._session.flush()
|
||||
return None
|
||||
value = outputs_dict[variable.name]
|
||||
value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, value)
|
||||
value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, output_value)
|
||||
# Extract variable value using unified logic
|
||||
variable.set_value(value_seg)
|
||||
variable.last_edited_at = None # Reset to indicate this is a reset operation
|
||||
|
|
@ -291,10 +300,8 @@ class WorkflowDraftVariableService:
|
|||
variable_type = variable.get_variable_type()
|
||||
if variable_type == DraftVariableType.CONVERSATION:
|
||||
return self._reset_conv_var(workflow, variable)
|
||||
elif variable_type == DraftVariableType.NODE:
|
||||
return self._reset_node_var(workflow, variable)
|
||||
else:
|
||||
raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}")
|
||||
return self._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
def delete_variable(self, variable: WorkflowDraftVariable):
|
||||
self._session.delete(variable)
|
||||
|
|
@ -439,6 +446,9 @@ def _batch_upsert_draft_varaible(
|
|||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
|
||||
set_={
|
||||
# Refresh creation timestamp to ensure updated variables
|
||||
# appear first in chronologically sorted result sets.
|
||||
"created_at": stmt.excluded.created_at,
|
||||
"updated_at": stmt.excluded.updated_at,
|
||||
"last_edited_at": stmt.excluded.last_edited_at,
|
||||
"description": stmt.excluded.description,
|
||||
|
|
@ -525,9 +535,6 @@ class DraftVariableSaver:
|
|||
# The type of the current node (see NodeType).
|
||||
_node_type: NodeType
|
||||
|
||||
# Indicates how the workflow execution was triggered (see InvokeFrom).
|
||||
_invoke_from: InvokeFrom
|
||||
|
||||
#
|
||||
_node_execution_id: str
|
||||
|
||||
|
|
@ -546,15 +553,16 @@ class DraftVariableSaver:
|
|||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
invoke_from: InvokeFrom,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
):
|
||||
# Important: `node_execution_id` parameter refers to the primary key (`id`) of the
|
||||
# WorkflowNodeExecutionModel/WorkflowNodeExecution, not their `node_execution_id`
|
||||
# field. These are distinct database fields with different purposes.
|
||||
self._session = session
|
||||
self._app_id = app_id
|
||||
self._node_id = node_id
|
||||
self._node_type = node_type
|
||||
self._invoke_from = invoke_from
|
||||
self._node_execution_id = node_execution_id
|
||||
self._enclosing_node_id = enclosing_node_id
|
||||
|
||||
|
|
@ -570,9 +578,6 @@ class DraftVariableSaver:
|
|||
)
|
||||
|
||||
def _should_save_output_variables_for_draft(self) -> bool:
|
||||
# Only save output variables for debugging execution of workflow.
|
||||
if self._invoke_from != InvokeFrom.DEBUGGER:
|
||||
return False
|
||||
if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER:
|
||||
# Currently we do not save output variables for nodes inside loop or iteration.
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
|
|||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables import Variable
|
||||
|
|
@ -414,7 +413,6 @@ class WorkflowService:
|
|||
app_id=app_model.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=NodeType(workflow_node_execution.node_type),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
node_execution_id=node_execution.id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.output_parser.structured_output import _parse_structured_output
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
|
@ -277,29 +278,6 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
|
|||
|
||||
|
||||
def test_extract_json():
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_config": {
|
||||
"structured_output": {
|
||||
"enabled": True,
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "number"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
"prompt_template": [{"role": "user", "text": "{{#sys.query#}}"}],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
},
|
||||
)
|
||||
llm_texts = [
|
||||
'<think>\n\n</think>{"name": "test", "age": 123', # resoning model (deepseek-r1)
|
||||
'{"name":"test","age":123}', # json schema model (gpt-4o)
|
||||
|
|
@ -308,4 +286,4 @@ def test_extract_json():
|
|||
'{"name":"test",age:123}', # without quotes (qwen-2.5-0.5b)
|
||||
]
|
||||
result = {"name": "test", "age": 123}
|
||||
assert all(node._parse_structured_output(item) == result for item in llm_texts)
|
||||
assert all(_parse_structured_output(item) == result for item in llm_texts)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,259 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterFetchFilesFromVariableValue:
|
||||
"""Test class for WorkflowResponseConverter._fetch_files_from_variable_value method"""
|
||||
|
||||
def create_test_file(self, file_id: str = "test_file_1") -> File:
|
||||
"""Create a test File object"""
|
||||
return File(
|
||||
id=file_id,
|
||||
tenant_id="test_tenant",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related_123",
|
||||
filename=f"{file_id}.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="storage_key_123",
|
||||
)
|
||||
|
||||
def create_file_dict(self, file_id: str = "test_file_dict") -> dict:
|
||||
"""Create a file dictionary with correct dify_model_identity"""
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"id": file_id,
|
||||
"tenant_id": "test_tenant",
|
||||
"type": "document",
|
||||
"transfer_method": "local_file",
|
||||
"related_id": "related_456",
|
||||
"filename": f"{file_id}.txt",
|
||||
"extension": ".txt",
|
||||
"mime_type": "text/plain",
|
||||
"size": 2048,
|
||||
"url": "http://example.com/file.txt",
|
||||
}
|
||||
|
||||
def test_fetch_files_from_variable_value_with_none(self):
|
||||
"""Test with None input"""
|
||||
# The method signature expects Union[dict, list, Segment], but implementation handles None
|
||||
# We'll test the actual behavior by passing an empty dict instead
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_empty_dict(self):
|
||||
"""Test with empty dictionary"""
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value({})
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_empty_list(self):
|
||||
"""Test with empty list"""
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([])
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_file_segment(self):
|
||||
"""Test with valid FileSegment"""
|
||||
test_file = self.create_test_file("segment_file")
|
||||
file_segment = FileSegment(value=test_file)
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], dict)
|
||||
assert result[0]["id"] == "segment_file"
|
||||
assert result[0]["dify_model_identity"] == FILE_MODEL_IDENTITY
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_single(self):
|
||||
"""Test with ArrayFileSegment containing single file"""
|
||||
test_file = self.create_test_file("array_file_1")
|
||||
array_segment = ArrayFileSegment(value=[test_file])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], dict)
|
||||
assert result[0]["id"] == "array_file_1"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_multiple(self):
|
||||
"""Test with ArrayFileSegment containing multiple files"""
|
||||
test_file_1 = self.create_test_file("array_file_1")
|
||||
test_file_2 = self.create_test_file("array_file_2")
|
||||
array_segment = ArrayFileSegment(value=[test_file_1, test_file_2])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "array_file_1"
|
||||
assert result[1]["id"] == "array_file_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_empty(self):
|
||||
"""Test with ArrayFileSegment containing empty array"""
|
||||
array_segment = ArrayFileSegment(value=[])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_of_file_dicts(self):
|
||||
"""Test with list containing file dictionaries"""
|
||||
file_dict_1 = self.create_file_dict("list_file_1")
|
||||
file_dict_2 = self.create_file_dict("list_file_2")
|
||||
test_list = [file_dict_1, file_dict_2]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "list_file_1"
|
||||
assert result[1]["id"] == "list_file_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_of_file_objects(self):
|
||||
"""Test with list containing File objects"""
|
||||
file_obj_1 = self.create_test_file("list_obj_1")
|
||||
file_obj_2 = self.create_test_file("list_obj_2")
|
||||
test_list = [file_obj_1, file_obj_2]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "list_obj_1"
|
||||
assert result[1]["id"] == "list_obj_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_mixed_valid_invalid(self):
|
||||
"""Test with list containing mix of valid files and invalid items"""
|
||||
file_dict = self.create_file_dict("mixed_file")
|
||||
invalid_dict = {"not_a_file": "value"}
|
||||
test_list = [file_dict, invalid_dict, "string_item", 123]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == "mixed_file"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_nested_structures(self):
|
||||
"""Test with list containing nested structures"""
|
||||
file_dict = self.create_file_dict("nested_file")
|
||||
nested_list = [file_dict, ["inner_list"]]
|
||||
test_list = [nested_list, {"nested": "dict"}]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
# Should not process nested structures in list items
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_incorrect_identity(self):
|
||||
"""Test with dictionary having incorrect dify_model_identity"""
|
||||
invalid_dict = {"dify_model_identity": "wrong_identity", "id": "invalid_file", "filename": "test.txt"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_missing_identity(self):
|
||||
"""Test with dictionary missing dify_model_identity"""
|
||||
invalid_dict = {"id": "no_identity_file", "filename": "test.txt"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_file_object(self):
|
||||
"""Test with dictionary containing File object"""
|
||||
file_obj = self.create_test_file("dict_obj_file")
|
||||
test_dict = {"file_key": file_obj}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_dict)
|
||||
|
||||
# Should not extract File objects from dict values
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_mixed_data_types(self):
|
||||
"""Test with various mixed data types"""
|
||||
mixed_data = {"string": "text", "number": 42, "boolean": True, "null": None, "dify_model_identity": "wrong"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(mixed_data)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_invalid_objects(self):
|
||||
"""Test with invalid objects that are not supported types"""
|
||||
# Test with an invalid dict that doesn't match expected patterns
|
||||
invalid_dict = {"custom_key": "custom_value"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_string_input(self):
|
||||
"""Test with string input (unsupported type)"""
|
||||
# Since method expects Union[dict, list, Segment], test with empty list instead
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_number_input(self):
|
||||
"""Test with number input (unsupported type)"""
|
||||
# Test with list containing numbers (should be ignored)
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([42, "string", None])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_return_type_is_sequence(self):
|
||||
"""Test that return type is Sequence[Mapping[str, Any]]"""
|
||||
file_dict = self.create_file_dict("type_test_file")
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_dict)
|
||||
|
||||
assert isinstance(result, Sequence)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], Mapping)
|
||||
assert all(isinstance(key, str) for key in result[0])
|
||||
|
||||
def test_fetch_files_from_variable_value_preserves_file_properties(self):
|
||||
"""Test that all file properties are preserved in the result"""
|
||||
original_file = self.create_test_file("property_test")
|
||||
file_segment = FileSegment(value=original_file)
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
file_dict = result[0]
|
||||
assert file_dict["id"] == "property_test"
|
||||
assert file_dict["tenant_id"] == "test_tenant"
|
||||
assert file_dict["type"] == "document"
|
||||
assert file_dict["transfer_method"] == "local_file"
|
||||
assert file_dict["filename"] == "property_test.txt"
|
||||
assert file_dict["extension"] == ".txt"
|
||||
assert file_dict["mime_type"] == "text/plain"
|
||||
assert file_dict["size"] == 1024
|
||||
|
||||
def test_fetch_files_from_variable_value_with_complex_nested_scenario(self):
|
||||
"""Test complex scenario with nested valid and invalid data"""
|
||||
file_dict = self.create_file_dict("complex_file")
|
||||
file_obj = self.create_test_file("complex_obj")
|
||||
|
||||
# Complex nested structure
|
||||
complex_data = [
|
||||
file_dict, # Valid file dict
|
||||
file_obj, # Valid file object
|
||||
{ # Invalid dict
|
||||
"not_file": "data",
|
||||
"nested": {"deep": "value"},
|
||||
},
|
||||
[ # Nested list (should be ignored)
|
||||
self.create_file_dict("nested_file")
|
||||
],
|
||||
"string", # Invalid string
|
||||
None, # None value
|
||||
42, # Invalid number
|
||||
]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(complex_data)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "complex_file"
|
||||
assert result[1]["id"] == "complex_obj"
|
||||
|
|
@ -6,12 +6,11 @@ from unittest.mock import Mock, patch
|
|||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables import StringSegment
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes import NodeType
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
VariableResetError,
|
||||
|
|
@ -32,7 +31,6 @@ class TestDraftVariableSaver:
|
|||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
node_type=NodeType.START,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
node_execution_id="test_execution_id",
|
||||
)
|
||||
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
|
||||
|
|
@ -79,7 +77,6 @@ class TestDraftVariableSaver:
|
|||
app_id=test_app_id,
|
||||
node_id=_NODE_ID,
|
||||
node_type=NodeType.START,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
node_execution_id="test_execution_id",
|
||||
)
|
||||
for idx, c in enumerate(cases, 1):
|
||||
|
|
@ -94,45 +91,70 @@ class TestWorkflowDraftVariableService:
|
|||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def _create_test_workflow(self, app_id: str) -> Workflow:
|
||||
"""Create a real Workflow instance for testing"""
|
||||
return Workflow.new(
|
||||
tenant_id="test_tenant_id",
|
||||
app_id=app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features="{}",
|
||||
created_by="test_user_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
def test_reset_conversation_variable(self):
|
||||
"""Test resetting a conversation variable"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real conversation variable
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id, name="test_var", value=test_value, description="Test conversation variable"
|
||||
)
|
||||
|
||||
# Mock the _reset_conv_var method
|
||||
expected_result = Mock(spec=WorkflowDraftVariable)
|
||||
expected_result = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id,
|
||||
name="test_var",
|
||||
value=StringSegment(value="reset_value"),
|
||||
)
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
|
||||
result = service.reset_variable(mock_workflow, mock_variable)
|
||||
result = service.reset_variable(workflow, variable)
|
||||
|
||||
mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable)
|
||||
mock_reset_conv.assert_called_once_with(workflow, variable)
|
||||
assert result == expected_result
|
||||
|
||||
def test_reset_node_variable_with_no_execution_id(self):
|
||||
"""Test resetting a node variable with no execution ID - should delete variable"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with no execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = None
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
# Create real node variable with no execution ID
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
name="test_var",
|
||||
value=test_value,
|
||||
node_execution_id="exec-id", # Set initially
|
||||
)
|
||||
# Manually set to None to simulate the test condition
|
||||
variable.node_execution_id = None
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=mock_variable)
|
||||
mock_session.delete.assert_called_once_with(instance=variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
|
|
@ -140,25 +162,25 @@ class TestWorkflowDraftVariableService:
|
|||
"""Test resetting a node variable when execution record doesn't exist"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = "exec-id"
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with execution ID
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
|
||||
)
|
||||
|
||||
# Mock session.scalars to return None (no execution record found)
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.first.return_value = None
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=mock_variable)
|
||||
mock_session.delete.assert_called_once_with(instance=variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
|
|
@ -166,17 +188,15 @@ class TestWorkflowDraftVariableService:
|
|||
"""Test resetting a node variable with valid execution record - should restore from execution"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = "exec-id"
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
mock_variable.node_id = "node-id"
|
||||
mock_variable.value_type = SegmentType.STRING
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with execution ID
|
||||
test_value = StringSegment(value="original_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
|
|
@ -190,33 +210,164 @@ class TestWorkflowDraftVariableService:
|
|||
|
||||
# Mock workflow methods
|
||||
mock_node_config = {"type": "test_node"}
|
||||
mock_workflow.get_node_config_by_id.return_value = mock_node_config
|
||||
mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM
|
||||
with (
|
||||
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
|
||||
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
|
||||
):
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
# Verify last_edited_at was reset
|
||||
assert variable.last_edited_at is None
|
||||
# Verify session.flush was called
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
# Verify variable.set_value was called with the correct value
|
||||
mock_variable.set_value.assert_called_once()
|
||||
# Verify last_edited_at was reset
|
||||
assert mock_variable.last_edited_at is None
|
||||
# Verify session.flush was called
|
||||
mock_session.flush.assert_called()
|
||||
# Should return the updated variable
|
||||
assert result == variable
|
||||
|
||||
# Should return the updated variable
|
||||
assert result == mock_variable
|
||||
|
||||
def test_reset_system_variable_raises_error(self):
|
||||
"""Test that resetting a system variable raises an error"""
|
||||
def test_reset_non_editable_system_variable_raises_error(self):
|
||||
"""Test that resetting a non-editable system variable raises an error"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.SYS # Not a valid enum value for this test
|
||||
mock_variable.id = "var-id"
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
service.reset_variable(mock_workflow, mock_variable)
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert "variable_id=var-id" in str(exc_info.value)
|
||||
# Create a non-editable system variable (workflow_id is not editable)
|
||||
test_value = StringSegment(value="test_workflow_id")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="workflow_id", # This is not in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=False, # Non-editable system variable
|
||||
)
|
||||
|
||||
# Mock the service to properly check system variable editability
|
||||
with patch.object(service, "reset_variable") as mock_reset:
|
||||
|
||||
def side_effect(wf, var):
|
||||
if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name):
|
||||
raise VariableResetError(f"cannot reset system variable, variable_id={var.id}")
|
||||
return var
|
||||
|
||||
mock_reset.side_effect = side_effect
|
||||
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
service.reset_variable(workflow, variable)
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert f"variable_id={variable.id}" in str(exc_info.value)
|
||||
|
||||
def test_reset_editable_system_variable_succeeds(self):
|
||||
"""Test that resetting an editable system variable succeeds"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create an editable system variable (files is editable)
|
||||
test_value = StringSegment(value="[]")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="files", # This is in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=True, # Editable system variable
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"sys.files": "[]"}
|
||||
|
||||
# Mock session.scalars to return the execution record
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.first.return_value = mock_execution
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should succeed and return the variable
|
||||
assert result == variable
|
||||
assert variable.last_edited_at is None
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
def test_reset_query_system_variable_succeeds(self):
|
||||
"""Test that resetting query system variable (another editable one) succeeds"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create an editable system variable (query is editable)
|
||||
test_value = StringSegment(value="original query")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="query", # This is in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=True, # Editable system variable
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"sys.query": "reset query"}
|
||||
|
||||
# Mock session.scalars to return the execution record
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.first.return_value = mock_execution
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should succeed and return the variable
|
||||
assert result == variable
|
||||
assert variable.last_edited_at is None
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
def test_system_variable_editability_check(self):
|
||||
"""Test the system variable editability function directly"""
|
||||
# Test editable system variables
|
||||
assert is_system_variable_editable("files") == True
|
||||
assert is_system_variable_editable("query") == True
|
||||
|
||||
# Test non-editable system variables
|
||||
assert is_system_variable_editable("workflow_id") == False
|
||||
assert is_system_variable_editable("conversation_id") == False
|
||||
assert is_system_variable_editable("user_id") == False
|
||||
|
||||
def test_workflow_draft_variable_factory_methods(self):
|
||||
"""Test that factory methods create proper instances"""
|
||||
test_app_id = self._get_test_app_id()
|
||||
test_value = StringSegment(value="test_value")
|
||||
|
||||
# Test conversation variable factory
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id, name="conv_var", value=test_value, description="Test conversation variable"
|
||||
)
|
||||
assert conv_var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
assert conv_var.editable == True
|
||||
assert conv_var.node_execution_id is None
|
||||
|
||||
# Test system variable factory
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id, name="workflow_id", value=test_value, node_execution_id="exec-id", editable=False
|
||||
)
|
||||
assert sys_var.get_variable_type() == DraftVariableType.SYS
|
||||
assert sys_var.editable == False
|
||||
assert sys_var.node_execution_id == "exec-id"
|
||||
|
||||
# Test node variable factory
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="node-id",
|
||||
name="node_var",
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
assert node_var.get_variable_type() == DraftVariableType.NODE
|
||||
assert node_var.visible == True
|
||||
assert node_var.editable == True
|
||||
assert node_var.node_execution_id == "exec-id"
|
||||
|
|
|
|||
20
api/uv.lock
20
api/uv.lock
|
|
@ -1198,6 +1198,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.5.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "authlib" },
|
||||
|
|
@ -1547,7 +1548,7 @@ vdb = [
|
|||
{ name = "pymochow", specifier = "==1.3.1" },
|
||||
{ name = "pyobvector", specifier = "~=0.1.6" },
|
||||
{ name = "qdrant-client", specifier = "==1.9.0" },
|
||||
{ name = "tablestore", specifier = "==6.1.0" },
|
||||
{ name = "tablestore", specifier = "==6.2.0" },
|
||||
{ name = "tcvectordb", specifier = "~=1.6.4" },
|
||||
{ name = "tidb-vector", specifier = "==0.0.9" },
|
||||
{ name = "upstash-vector", specifier = "==0.6.0" },
|
||||
|
|
@ -1654,15 +1655,6 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/91/db/a0335710caaa6d0aebdaa65ad4df789c15d89b7babd9a30277838a7d9aac/emoji-2.14.1-py3-none-any.whl", hash = "sha256:35a8a486c1460addb1499e3bf7929d3889b2e2841a57401903699fef595e942b", size = 590617, upload-time = "2025-01-16T06:31:23.526Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "enum34"
|
||||
version = "1.1.10"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/11/c4/2da1f4952ba476677a42f25cd32ab8aaf0e1c0d0e00b89822b835c7e654c/enum34-1.1.10.tar.gz", hash = "sha256:cce6a7477ed816bd2542d03d53db9f0db935dd013b70f336a95c73979289f248", size = 28187, upload-time = "2020-03-10T17:48:00.865Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/63/f6/ccb1c83687756aeabbf3ca0f213508fcfb03883ff200d201b3a4c60cedcc/enum34-1.1.10-py3-none-any.whl", hash = "sha256:c3858660960c984d6ab0ebad691265180da2b43f07e061c0f8dca9ef3cffd328", size = 11224, upload-time = "2020-03-10T17:48:03.174Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "esdk-obs-python"
|
||||
version = "3.24.6.1"
|
||||
|
|
@ -5375,12 +5367,11 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "tablestore"
|
||||
version = "6.1.0"
|
||||
version = "6.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "crc32c" },
|
||||
{ name = "enum34" },
|
||||
{ name = "flatbuffers" },
|
||||
{ name = "future" },
|
||||
{ name = "numpy" },
|
||||
|
|
@ -5388,7 +5379,10 @@ dependencies = [
|
|||
{ name = "six" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0f/ed/5bdd906ec9d2dbae3909525dbb7602558c377e0cbcdddb6405d2d0d3f1af/tablestore-6.1.0.tar.gz", hash = "sha256:bfe6a3e0fe88a230729723c357f4a46b8869a06a4b936db20692ed587a721c1c", size = 135690, upload-time = "2024-12-20T07:38:37.428Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/58/48d65d181a69f7db19f7cdee01d252168fbfbad2d1bb25abed03e6df3b05/tablestore-6.2.0.tar.gz", hash = "sha256:0773e77c00542be1bfebbc3c7a85f72a881c63e4e7df7c5a9793a54144590e68", size = 85942, upload-time = "2025-04-15T12:11:20.655Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/da/30451712a769bcf417add8e81163d478a4d668b0e8d489a9d667260d55df/tablestore-6.2.0-py3-none-any.whl", hash = "sha256:6af496d841ab1ff3f78b46abbd87b95a08d89605c51664d2b30933b1d1c5583a", size = 106297, upload-time = "2025-04-15T12:11:17.476Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabulate"
|
||||
|
|
|
|||
|
|
@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.."
|
|||
|
||||
# run mypy checks
|
||||
uv run --directory api --dev --with pip \
|
||||
python -m mypy --install-types --non-interactive ./
|
||||
python -m mypy --install-types --non-interactive --exclude venv ./
|
||||
|
|
|
|||
|
|
@ -285,6 +285,7 @@ BROKER_USE_SSL=false
|
|||
# If you are using Redis Sentinel for high availability, configure the following settings.
|
||||
CELERY_USE_SENTINEL=false
|
||||
CELERY_SENTINEL_MASTER_NAME=
|
||||
CELERY_SENTINEL_PASSWORD=
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT=0.1
|
||||
|
||||
# ------------------------------
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
|
|||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.5.0
|
||||
image: langgenius/dify-api:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -31,7 +31,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:1.5.0
|
||||
image: langgenius/dify-api:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -57,7 +57,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.5.0
|
||||
image: langgenius/dify-web:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
@ -142,7 +142,7 @@ services:
|
|||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.1.2-local
|
||||
image: langgenius/dify-plugin-daemon:0.1.3-local
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ services:
|
|||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.1.2-local
|
||||
image: langgenius/dify-plugin-daemon:0.1.3-local
|
||||
restart: always
|
||||
env_file:
|
||||
- ./middleware.env
|
||||
|
|
|
|||
|
|
@ -79,6 +79,7 @@ x-shared-env: &shared-api-worker-env
|
|||
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
|
||||
CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false}
|
||||
CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-}
|
||||
CELERY_SENTINEL_PASSWORD: ${CELERY_SENTINEL_PASSWORD:-}
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1}
|
||||
WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*}
|
||||
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
|
||||
|
|
@ -516,7 +517,7 @@ x-shared-env: &shared-api-worker-env
|
|||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.5.0
|
||||
image: langgenius/dify-api:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -545,7 +546,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:1.5.0
|
||||
image: langgenius/dify-api:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -571,7 +572,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.5.0
|
||||
image: langgenius/dify-web:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
@ -656,7 +657,7 @@ services:
|
|||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.1.2-local
|
||||
image: langgenius/dify-plugin-daemon:0.1.3-local
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ import AccessControl from '@/app/components/app/app-access-control'
|
|||
import { AccessMode } from '@/models/access-control'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { formatTime } from '@/utils/time'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
|
||||
export type AppCardProps = {
|
||||
app: App
|
||||
|
|
@ -190,6 +191,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
|||
}, [onRefresh, mutateApps, setShowAccessControl])
|
||||
|
||||
const Operations = (props: HtmlContentProps) => {
|
||||
const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp } = useGetUserCanAccessApp({ appId: app?.id, enabled: (!!props?.open && systemFeatures.webapp_auth.enabled) })
|
||||
const onMouseLeave = async () => {
|
||||
props.onClose?.()
|
||||
}
|
||||
|
|
@ -267,10 +269,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
|||
</button>
|
||||
</>
|
||||
)}
|
||||
<Divider className="my-1" />
|
||||
<button className='mx-1 flex h-8 cursor-pointer items-center gap-2 rounded-lg px-3 hover:bg-state-base-hover' onClick={onClickInstalledApp}>
|
||||
<span className='system-sm-regular text-text-secondary'>{t('app.openInExplore')}</span>
|
||||
</button>
|
||||
{
|
||||
(isGettingUserCanAccessApp || !userCanAccessApp?.result) ? null : <>
|
||||
<Divider className="my-1" />
|
||||
<button className='mx-1 flex h-8 cursor-pointer items-center gap-2 rounded-lg px-3 hover:bg-state-base-hover' onClick={onClickInstalledApp}>
|
||||
<span className='system-sm-regular text-text-secondary'>{t('app.openInExplore')}</span>
|
||||
</button>
|
||||
</>
|
||||
}
|
||||
<Divider className="my-1" />
|
||||
{
|
||||
systemFeatures.webapp_auth.enabled && isCurrentWorkspaceEditor && <>
|
||||
|
|
|
|||
|
|
@ -25,10 +25,13 @@ const Layout: FC<{
|
|||
}
|
||||
|
||||
let appCode: string | null = null
|
||||
if (redirectUrl)
|
||||
appCode = redirectUrl?.split('/').pop() || null
|
||||
else
|
||||
if (redirectUrl) {
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
appCode = url.pathname.split('/').pop() || null
|
||||
}
|
||||
else {
|
||||
appCode = pathname.split('/').pop() || null
|
||||
}
|
||||
|
||||
if (!appCode)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -25,7 +25,10 @@ export default function CheckCode() {
|
|||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
@ -62,7 +65,7 @@ export default function CheckCode() {
|
|||
localStorage.setItem('webapp_access_token', ret.data.access_token)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: ret.data.access_token })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}
|
||||
}
|
||||
catch (error) { console.error(error) }
|
||||
|
|
|
|||
|
|
@ -23,7 +23,10 @@ const ExternalMemberSSOAuth = () => {
|
|||
}
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
'use client'
|
||||
import Link from 'next/link'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
|
@ -33,7 +34,10 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
|||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
@ -87,7 +91,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
|||
localStorage.setItem('webapp_access_token', res.data.access_token)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: res.data.access_token })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}
|
||||
else {
|
||||
Toast.notify({
|
||||
|
|
|
|||
|
|
@ -23,7 +23,10 @@ const SSOAuth: FC<SSOAuthProps> = ({
|
|||
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,10 @@ const WebSSOForm: FC = () => {
|
|||
}
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
@ -63,20 +66,20 @@ const WebSSOForm: FC = () => {
|
|||
localStorage.setItem('webapp_access_token', tokenFromUrl)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: tokenFromUrl })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
return
|
||||
}
|
||||
if (appCode && redirectUrl && localStorage.getItem('webapp_access_token')) {
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: localStorage.getItem('webapp_access_token') })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}
|
||||
})()
|
||||
}, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl, message])
|
||||
|
||||
useEffect(() => {
|
||||
if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC && redirectUrl)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}, [webAppAccessMode, router, redirectUrl])
|
||||
|
||||
if (tokenFromUrl) {
|
||||
|
|
|
|||
|
|
@ -80,6 +80,8 @@ import {
|
|||
import PluginDependency from '@/app/components/workflow/plugin-dependency'
|
||||
import { supportFunctionCall } from '@/utils/tool-call'
|
||||
import { MittProvider } from '@/context/mitt-context'
|
||||
import { fetchAndMergeValidCompletionParams } from '@/utils/completion-params'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
type PublishConfig = {
|
||||
modelConfig: ModelConfig
|
||||
|
|
@ -453,7 +455,21 @@ const Configuration: FC = () => {
|
|||
...visionConfig,
|
||||
enabled: supportVision,
|
||||
}, true)
|
||||
setCompletionParams({})
|
||||
|
||||
try {
|
||||
const { params: filtered, removedDetails } = await fetchAndMergeValidCompletionParams(
|
||||
provider,
|
||||
modelId,
|
||||
completionParams,
|
||||
)
|
||||
if (Object.keys(removedDetails).length)
|
||||
Toast.notify({ type: 'warning', message: `${t('common.modelProvider.parametersInvalidRemoved')}: ${Object.entries(removedDetails).map(([k, reason]) => `${k} (${reason})`).join(', ')}` })
|
||||
setCompletionParams(filtered)
|
||||
}
|
||||
catch (e) {
|
||||
Toast.notify({ type: 'error', message: t('common.error') })
|
||||
setCompletionParams({})
|
||||
}
|
||||
}
|
||||
|
||||
const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision)
|
||||
|
|
|
|||
|
|
@ -191,6 +191,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
|||
const { userProfile: { timezone } } = useAppContext()
|
||||
const { formatTime } = useTimestamp()
|
||||
const { onClose, appDetail } = useContext(DrawerContext)
|
||||
const { notify } = useContext(ToastContext)
|
||||
const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({
|
||||
currentLogItem: state.currentLogItem,
|
||||
setCurrentLogItem: state.setCurrentLogItem,
|
||||
|
|
@ -312,18 +313,34 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
|||
return item
|
||||
}))
|
||||
}, [allChatItems])
|
||||
const handleAnnotationRemoved = useCallback((index: number) => {
|
||||
setAllChatItems(allChatItems.map((item, i) => {
|
||||
if (i === index) {
|
||||
return {
|
||||
...item,
|
||||
content: item.content,
|
||||
annotation: undefined,
|
||||
}
|
||||
const handleAnnotationRemoved = useCallback(async (index: number): Promise<boolean> => {
|
||||
const annotation = allChatItems[index]?.annotation
|
||||
|
||||
try {
|
||||
if (annotation?.id) {
|
||||
const { delAnnotation } = await import('@/service/annotation')
|
||||
await delAnnotation(appDetail?.id || '', annotation.id)
|
||||
}
|
||||
return item
|
||||
}))
|
||||
}, [allChatItems])
|
||||
|
||||
setAllChatItems(allChatItems.map((item, i) => {
|
||||
if (i === index) {
|
||||
return {
|
||||
...item,
|
||||
content: item.content,
|
||||
annotation: undefined,
|
||||
}
|
||||
}
|
||||
return item
|
||||
}))
|
||||
|
||||
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
|
||||
return true
|
||||
}
|
||||
catch {
|
||||
notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
|
||||
return false
|
||||
}
|
||||
}, [allChatItems, appDetail?.id, t])
|
||||
|
||||
const fetchInitiated = useRef(false)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import { RiRefreshLine } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import TooltipPlus from '@/app/components/base/tooltip'
|
||||
|
||||
type Props = {
|
||||
className?: string,
|
||||
popupContent?: string,
|
||||
onClick: () => void
|
||||
}
|
||||
|
||||
const SyncButton: FC<Props> = ({
|
||||
className,
|
||||
popupContent = '',
|
||||
onClick,
|
||||
}) => {
|
||||
return (
|
||||
<TooltipPlus popupContent={popupContent}>
|
||||
<div className={cn(className, 'cursor-pointer select-none rounded-md p-1 hover:bg-state-base-hover')} onClick={onClick}>
|
||||
<RiRefreshLine className='h-4 w-4 text-text-tertiary' />
|
||||
</div>
|
||||
</TooltipPlus>
|
||||
)
|
||||
}
|
||||
export default React.memo(SyncButton)
|
||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 6.9 KiB |
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 6.9 KiB |
|
|
@ -18,7 +18,7 @@ describe('InputNumber Component', () => {
|
|||
|
||||
it('renders input with default values', () => {
|
||||
render(<InputNumber {...defaultProps} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
const input = screen.getByRole('spinbutton')
|
||||
expect(input).toBeInTheDocument()
|
||||
})
|
||||
|
||||
|
|
@ -56,7 +56,7 @@ describe('InputNumber Component', () => {
|
|||
|
||||
it('handles direct input changes', () => {
|
||||
render(<InputNumber {...defaultProps} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
const input = screen.getByRole('spinbutton')
|
||||
|
||||
fireEvent.change(input, { target: { value: '42' } })
|
||||
expect(defaultProps.onChange).toHaveBeenCalledWith(42)
|
||||
|
|
@ -64,7 +64,7 @@ describe('InputNumber Component', () => {
|
|||
|
||||
it('handles empty input', () => {
|
||||
render(<InputNumber {...defaultProps} value={0} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
const input = screen.getByRole('spinbutton')
|
||||
|
||||
fireEvent.change(input, { target: { value: '' } })
|
||||
expect(defaultProps.onChange).toHaveBeenCalledWith(undefined)
|
||||
|
|
@ -72,7 +72,7 @@ describe('InputNumber Component', () => {
|
|||
|
||||
it('handles invalid input', () => {
|
||||
render(<InputNumber {...defaultProps} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
const input = screen.getByRole('spinbutton')
|
||||
|
||||
fireEvent.change(input, { target: { value: 'abc' } })
|
||||
expect(defaultProps.onChange).not.toHaveBeenCalled()
|
||||
|
|
@ -86,7 +86,7 @@ describe('InputNumber Component', () => {
|
|||
|
||||
it('disables controls when disabled prop is true', () => {
|
||||
render(<InputNumber {...defaultProps} disabled />)
|
||||
const input = screen.getByRole('textbox')
|
||||
const input = screen.getByRole('spinbutton')
|
||||
const incrementBtn = screen.getByRole('button', { name: /increment/i })
|
||||
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
|
||||
|
||||
|
|
|
|||
|
|
@ -84,8 +84,8 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
|
|||
return <div className={classNames('flex', wrapClassName)}>
|
||||
<Input {...rest}
|
||||
// disable default controller
|
||||
type='text'
|
||||
className={classNames('rounded-r-none', className)}
|
||||
type='number'
|
||||
className={classNames('no-spinner rounded-r-none', className)}
|
||||
value={value}
|
||||
max={max}
|
||||
min={min}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ export const preprocessLaTeX = (content: string) => {
|
|||
|
||||
const codeBlockRegex = /```[\s\S]*?```/g
|
||||
const codeBlocks = content.match(codeBlockRegex) || []
|
||||
const escapeReplacement = (str: string) => str.replace(/\$/g, '_TMP_REPLACE_DOLLAR_')
|
||||
let processedContent = content.replace(codeBlockRegex, 'CODE_BLOCK_PLACEHOLDER')
|
||||
|
||||
processedContent = flow([
|
||||
|
|
@ -21,9 +22,11 @@ export const preprocessLaTeX = (content: string) => {
|
|||
])(processedContent)
|
||||
|
||||
codeBlocks.forEach((block) => {
|
||||
processedContent = processedContent.replace('CODE_BLOCK_PLACEHOLDER', block)
|
||||
processedContent = processedContent.replace('CODE_BLOCK_PLACEHOLDER', escapeReplacement(block))
|
||||
})
|
||||
|
||||
processedContent = processedContent.replace(/_TMP_REPLACE_DOLLAR_/g, '$')
|
||||
|
||||
return processedContent
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import { Fragment, cloneElement, useRef } from 'react'
|
|||
import cn from '@/utils/classnames'
|
||||
|
||||
export type HtmlContentProps = {
|
||||
open?: boolean
|
||||
onClose?: () => void
|
||||
onClick?: () => void
|
||||
}
|
||||
|
|
@ -100,7 +101,8 @@ export default function CustomPopover({
|
|||
}
|
||||
>
|
||||
{cloneElement(htmlContent as React.ReactElement, {
|
||||
onClose: () => onMouseLeave(open),
|
||||
open,
|
||||
onClose: close,
|
||||
...(manualClose
|
||||
? {
|
||||
onClick: close,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useEffect, useState } from 'react'
|
||||
import React, { useEffect, useRef, useState } from 'react'
|
||||
import { Combobox, ComboboxButton, ComboboxInput, ComboboxOption, ComboboxOptions, Listbox, ListboxButton, ListboxOption, ListboxOptions } from '@headlessui/react'
|
||||
import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
|
||||
import Badge from '../badge/index'
|
||||
import { RiCheckLine } from '@remixicon/react'
|
||||
import { RiCheckLine, RiLoader4Line } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import classNames from '@/utils/classnames'
|
||||
import {
|
||||
|
|
@ -51,6 +51,8 @@ export type ISelectProps = {
|
|||
item: Item
|
||||
selected: boolean
|
||||
}) => React.ReactNode
|
||||
isLoading?: boolean
|
||||
onOpenChange?: (open: boolean) => void
|
||||
}
|
||||
const Select: FC<ISelectProps> = ({
|
||||
className,
|
||||
|
|
@ -114,7 +116,7 @@ const Select: FC<ISelectProps> = ({
|
|||
if (!disabled)
|
||||
setOpen(!open)
|
||||
}
|
||||
} className={classNames(`flex items-center h-9 w-full rounded-lg border-0 ${bgClassName} py-1.5 pl-3 pr-10 shadow-sm sm:text-sm sm:leading-6 focus-visible:outline-none focus-visible:bg-state-base-hover group-hover:bg-state-base-hover`, optionClassName)}>
|
||||
} className={classNames(`flex h-9 w-full items-center rounded-lg border-0 ${bgClassName} py-1.5 pl-3 pr-10 shadow-sm focus-visible:bg-state-base-hover focus-visible:outline-none group-hover:bg-state-base-hover sm:text-sm sm:leading-6`, optionClassName)}>
|
||||
<div className='w-0 grow truncate text-left' title={selectedItem?.name}>{selectedItem?.name}</div>
|
||||
</ComboboxButton>}
|
||||
<ComboboxButton className="absolute inset-y-0 right-0 flex items-center rounded-r-md px-2 focus:outline-none" onClick={
|
||||
|
|
@ -135,7 +137,7 @@ const Select: FC<ISelectProps> = ({
|
|||
value={item}
|
||||
className={({ active }: { active: boolean }) =>
|
||||
classNames(
|
||||
'relative cursor-default select-none py-2 pl-3 pr-9 rounded-lg hover:bg-state-base-hover text-text-secondary',
|
||||
'relative cursor-default select-none rounded-lg py-2 pl-3 pr-9 text-text-secondary hover:bg-state-base-hover',
|
||||
active ? 'bg-state-base-hover' : '',
|
||||
optionClassName,
|
||||
)
|
||||
|
|
@ -178,17 +180,20 @@ const SimpleSelect: FC<ISelectProps> = ({
|
|||
defaultValue = 1,
|
||||
disabled = false,
|
||||
onSelect,
|
||||
onOpenChange,
|
||||
placeholder,
|
||||
optionWrapClassName,
|
||||
optionClassName,
|
||||
hideChecked,
|
||||
notClearable,
|
||||
renderOption,
|
||||
isLoading = false,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const localPlaceholder = placeholder || t('common.placeholder.select')
|
||||
|
||||
const [selectedItem, setSelectedItem] = useState<Item | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
let defaultSelect = null
|
||||
const existed = items.find((item: Item) => item.value === defaultValue)
|
||||
|
|
@ -199,8 +204,10 @@ const SimpleSelect: FC<ISelectProps> = ({
|
|||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [defaultValue])
|
||||
|
||||
const listboxRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
return (
|
||||
<Listbox
|
||||
<Listbox ref={listboxRef}
|
||||
value={selectedItem}
|
||||
onChange={(value: Item) => {
|
||||
if (!disabled) {
|
||||
|
|
@ -212,10 +219,17 @@ const SimpleSelect: FC<ISelectProps> = ({
|
|||
<div className={classNames('group/simple-select relative h-9', wrapperClassName)}>
|
||||
{renderTrigger && <ListboxButton className='w-full'>{renderTrigger(selectedItem)}</ListboxButton>}
|
||||
{!renderTrigger && (
|
||||
<ListboxButton className={classNames(`flex items-center w-full h-full rounded-lg border-0 bg-components-input-bg-normal pl-3 pr-10 sm:text-sm sm:leading-6 focus-visible:outline-none focus-visible:bg-state-base-hover-alt group-hover/simple-select:bg-state-base-hover-alt ${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`, className)}>
|
||||
<span className={classNames('block truncate text-left system-sm-regular text-components-input-text-filled', !selectedItem?.name && 'text-components-input-text-placeholder')}>{selectedItem?.name ?? localPlaceholder}</span>
|
||||
<ListboxButton onClick={() => {
|
||||
// get data-open, use setTimeout to ensure the attribute is set
|
||||
setTimeout(() => {
|
||||
if (listboxRef.current)
|
||||
onOpenChange?.(listboxRef.current.getAttribute('data-open') !== null)
|
||||
})
|
||||
}} className={classNames(`flex h-full w-full items-center rounded-lg border-0 bg-components-input-bg-normal pl-3 pr-10 focus-visible:bg-state-base-hover-alt focus-visible:outline-none group-hover/simple-select:bg-state-base-hover-alt sm:text-sm sm:leading-6 ${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`, className)}>
|
||||
<span className={classNames('system-sm-regular block truncate text-left text-components-input-text-filled', !selectedItem?.name && 'text-components-input-text-placeholder')}>{selectedItem?.name ?? localPlaceholder}</span>
|
||||
<span className="absolute inset-y-0 right-0 flex items-center pr-2">
|
||||
{(selectedItem && !notClearable)
|
||||
{isLoading ? <RiLoader4Line className='h-3.5 w-3.5 animate-spin text-text-secondary' />
|
||||
: (selectedItem && !notClearable)
|
||||
? (
|
||||
<XMarkIcon
|
||||
onClick={(e) => {
|
||||
|
|
@ -237,14 +251,14 @@ const SimpleSelect: FC<ISelectProps> = ({
|
|||
</ListboxButton>
|
||||
)}
|
||||
|
||||
{!disabled && (
|
||||
<ListboxOptions className={classNames('absolute z-10 mt-1 px-1 max-h-60 w-full overflow-auto rounded-xl bg-components-panel-bg-blur backdrop-blur-sm py-1 text-base shadow-lg border-components-panel-border border-[0.5px] focus:outline-none sm:text-sm', optionWrapClassName)}>
|
||||
{(!disabled) && (
|
||||
<ListboxOptions className={classNames('absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur px-1 py-1 text-base shadow-lg backdrop-blur-sm focus:outline-none sm:text-sm', optionWrapClassName)}>
|
||||
{items.map((item: Item) => (
|
||||
<ListboxOption
|
||||
key={item.value}
|
||||
className={
|
||||
classNames(
|
||||
'relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-state-base-hover text-text-secondary',
|
||||
'relative cursor-pointer select-none rounded-lg py-2 pl-3 pr-9 text-text-secondary hover:bg-state-base-hover',
|
||||
optionClassName,
|
||||
)
|
||||
}
|
||||
|
|
@ -324,7 +338,7 @@ const PortalSelect: FC<PortalSelectProps> = ({
|
|||
: (
|
||||
<div
|
||||
className={classNames(`
|
||||
group flex items-center justify-between px-2.5 h-9 rounded-lg border-0 bg-components-input-bg-normal hover:bg-state-base-hover-alt text-sm ${readonly ? 'cursor-not-allowed' : 'cursor-pointer'}
|
||||
group flex h-9 items-center justify-between rounded-lg border-0 bg-components-input-bg-normal px-2.5 text-sm hover:bg-state-base-hover-alt ${readonly ? 'cursor-not-allowed' : 'cursor-pointer'}
|
||||
`, triggerClassName, triggerClassNameFn?.(open))}
|
||||
title={selectedItem?.name}
|
||||
>
|
||||
|
|
@ -344,7 +358,7 @@ const PortalSelect: FC<PortalSelectProps> = ({
|
|||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className={`z-20 ${popupClassName}`}>
|
||||
<div
|
||||
className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md text-base shadow-lg border-components-panel-border bg-components-panel-bg border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)}
|
||||
className={classNames('max-h-60 overflow-auto rounded-md border-[0.5px] border-components-panel-border bg-components-panel-bg px-1 py-1 text-base shadow-lg focus:outline-none sm:text-sm', popupInnerClassName)}
|
||||
>
|
||||
{items.map((item: Item) => (
|
||||
<div
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import { useState } from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
import type { ChangeEvent, FC, KeyboardEvent } from 'react'
|
||||
import { } from 'use-context-selector'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AutosizeInput from 'react-18-input-autosize'
|
||||
import { RiAddLine, RiCloseLine } from '@remixicon/react'
|
||||
|
|
@ -42,6 +41,29 @@ const TagInput: FC<TagInputProps> = ({
|
|||
onChange(copyItems)
|
||||
}
|
||||
|
||||
const handleNewTag = useCallback((value: string) => {
|
||||
const valueTrimmed = value.trim()
|
||||
if (!valueTrimmed) {
|
||||
notify({ type: 'error', message: t('datasetDocuments.segment.keywordEmpty') })
|
||||
return
|
||||
}
|
||||
|
||||
if ((items.find(item => item === valueTrimmed))) {
|
||||
notify({ type: 'error', message: t('datasetDocuments.segment.keywordDuplicate') })
|
||||
return
|
||||
}
|
||||
|
||||
if (valueTrimmed.length > 20) {
|
||||
notify({ type: 'error', message: t('datasetDocuments.segment.keywordError') })
|
||||
return
|
||||
}
|
||||
|
||||
onChange([...items, valueTrimmed])
|
||||
setTimeout(() => {
|
||||
setValue('')
|
||||
})
|
||||
}, [items, onChange, notify, t])
|
||||
|
||||
const handleKeyDown = (e: KeyboardEvent) => {
|
||||
if (isSpecialMode && e.key === 'Enter')
|
||||
setValue(`${value}↵`)
|
||||
|
|
@ -50,24 +72,12 @@ const TagInput: FC<TagInputProps> = ({
|
|||
if (isSpecialMode)
|
||||
e.preventDefault()
|
||||
|
||||
const valueTrimmed = value.trim()
|
||||
if (!valueTrimmed || (items.find(item => item === valueTrimmed)))
|
||||
return
|
||||
|
||||
if (valueTrimmed.length > 20) {
|
||||
notify({ type: 'error', message: t('datasetDocuments.segment.keywordError') })
|
||||
return
|
||||
}
|
||||
|
||||
onChange([...items, valueTrimmed])
|
||||
setTimeout(() => {
|
||||
setValue('')
|
||||
})
|
||||
handleNewTag(value)
|
||||
}
|
||||
}
|
||||
|
||||
const handleBlur = () => {
|
||||
setValue('')
|
||||
handleNewTag(value)
|
||||
setFocused(false)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -508,13 +508,15 @@ const StepTwo = ({
|
|||
const separator = rules.segmentation.separator
|
||||
const max = rules.segmentation.max_tokens
|
||||
const overlap = rules.segmentation.chunk_overlap
|
||||
const isHierarchicalDocument = documentDetail.doc_form === ChunkingMode.parentChild
|
||||
|| (rules.parent_mode && rules.subchunk_segmentation)
|
||||
setSegmentIdentifier(separator)
|
||||
setMaxChunkLength(max)
|
||||
setOverlap(overlap!)
|
||||
setRules(rules.pre_processing_rules)
|
||||
setDefaultConfig(rules)
|
||||
|
||||
if (documentDetail.dataset_process_rule.mode === 'hierarchical') {
|
||||
if (isHierarchicalDocument) {
|
||||
setParentChildConfig({
|
||||
chunkForContext: rules.parent_mode || 'paragraph',
|
||||
parent: {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import useEditDocumentMetadata from '../metadata/hooks/use-edit-dataset-metadata
|
|||
import DatasetMetadataDrawer from '../metadata/metadata-dataset/dataset-metadata-drawer'
|
||||
import StatusWithAction from '../common/document-status-with-action/status-with-action'
|
||||
import { useDocLink } from '@/context/i18n'
|
||||
import { useFetchDefaultProcessRule } from '@/service/knowledge/use-create-dataset'
|
||||
|
||||
const FolderPlusIcon = ({ className }: React.SVGProps<SVGElement>) => {
|
||||
return <svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
|
||||
|
|
@ -183,6 +184,8 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
|
|||
router.push(`/datasets/${datasetId}/documents/create`)
|
||||
}
|
||||
|
||||
const fetchDefaultProcessRuleMutation = useFetchDefaultProcessRule()
|
||||
|
||||
const handleSaveNotionPageSelected = async (selectedPages: NotionPage[]) => {
|
||||
const workspacesMap = groupBy(selectedPages, 'workspace_id')
|
||||
const workspaces = Object.keys(workspacesMap).map((workspaceId) => {
|
||||
|
|
@ -191,6 +194,7 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
|
|||
pages: workspacesMap[workspaceId],
|
||||
}
|
||||
})
|
||||
const { rules } = await fetchDefaultProcessRuleMutation.mutateAsync('/datasets/process-rule')
|
||||
const params = {
|
||||
data_source: {
|
||||
type: dataset?.data_source_type,
|
||||
|
|
@ -214,7 +218,7 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
|
|||
},
|
||||
indexing_technique: dataset?.indexing_technique,
|
||||
process_rule: {
|
||||
rules: {},
|
||||
rules,
|
||||
mode: ProcessMode.general,
|
||||
},
|
||||
} as CreateDocumentReq
|
||||
|
|
|
|||
|
|
@ -1124,6 +1124,129 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
|
||||
<hr className='ml-0 mr-0' />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
method='GET'
|
||||
title='Get Document Detail'
|
||||
name='#get-document-detail'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
Get a document's detail.
|
||||
### Path
|
||||
- `dataset_id` (string) Dataset ID
|
||||
- `document_id` (string) Document ID
|
||||
|
||||
### Query
|
||||
- `metadata` (string) Metadata filter, can be `all`, `only`, or `without`. Default is `all`.
|
||||
|
||||
### Response
|
||||
Returns the document's detail.
|
||||
</Col>
|
||||
<Col sticky>
|
||||
### Request Example
|
||||
<CodeGroup title="Request" tag="GET" label="/datasets/{dataset_id}/documents/{document_id}" targetCode={`curl -X GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}' \\\n-H 'Authorization: Bearer {api_key}'`}>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}' \
|
||||
-H 'Authorization: Bearer {api_key}'
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
### Response Example
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"id": "f46ae30c-5c11-471b-96d0-464f5f32a7b2",
|
||||
"position": 1,
|
||||
"data_source_type": "upload_file",
|
||||
"data_source_info": {
|
||||
"upload_file": {
|
||||
...
|
||||
}
|
||||
},
|
||||
"dataset_process_rule_id": "24b99906-845e-499f-9e3c-d5565dd6962c",
|
||||
"dataset_process_rule": {
|
||||
"mode": "hierarchical",
|
||||
"rules": {
|
||||
"pre_processing_rules": [
|
||||
{
|
||||
"id": "remove_extra_spaces",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "remove_urls_emails",
|
||||
"enabled": false
|
||||
}
|
||||
],
|
||||
"segmentation": {
|
||||
"separator": "**********page_ending**********",
|
||||
"max_tokens": 1024,
|
||||
"chunk_overlap": 0
|
||||
},
|
||||
"parent_mode": "paragraph",
|
||||
"subchunk_segmentation": {
|
||||
"separator": "\n",
|
||||
"max_tokens": 512,
|
||||
"chunk_overlap": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"document_process_rule": {
|
||||
"id": "24b99906-845e-499f-9e3c-d5565dd6962c",
|
||||
"dataset_id": "48a0db76-d1a9-46c1-ae35-2baaa919a8a9",
|
||||
"mode": "hierarchical",
|
||||
"rules": {
|
||||
"pre_processing_rules": [
|
||||
{
|
||||
"id": "remove_extra_spaces",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "remove_urls_emails",
|
||||
"enabled": false
|
||||
}
|
||||
],
|
||||
"segmentation": {
|
||||
"separator": "**********page_ending**********",
|
||||
"max_tokens": 1024,
|
||||
"chunk_overlap": 0
|
||||
},
|
||||
"parent_mode": "paragraph",
|
||||
"subchunk_segmentation": {
|
||||
"separator": "\n",
|
||||
"max_tokens": 512,
|
||||
"chunk_overlap": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "xxxx",
|
||||
"created_from": "web",
|
||||
"created_by": "17f71940-a7b5-4c77-b60f-2bd645c1ffa0",
|
||||
"created_at": 1750464191,
|
||||
"tokens": null,
|
||||
"indexing_status": "waiting",
|
||||
"completed_at": null,
|
||||
"updated_at": 1750464191,
|
||||
"indexing_latency": null,
|
||||
"error": null,
|
||||
"enabled": true,
|
||||
"disabled_at": null,
|
||||
"disabled_by": null,
|
||||
"archived": false,
|
||||
"segment_count": 0,
|
||||
"average_segment_length": 0,
|
||||
"hit_count": null,
|
||||
"display_status": "queuing",
|
||||
"doc_form": "hierarchical_model",
|
||||
"doc_language": "Chinese Simplified"
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
___
|
||||
<hr className='ml-0 mr-0' />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/status/{action}'
|
||||
method='PATCH'
|
||||
|
|
|
|||
|
|
@ -881,6 +881,130 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
|
||||
<hr className='ml-0 mr-0' />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
method='GET'
|
||||
title='ドキュメントの詳細を取得'
|
||||
name='#get-document-detail'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
ドキュメントの詳細を取得.
|
||||
### Path
|
||||
- `dataset_id` (string) ナレッジベースID
|
||||
- `document_id` (string) ドキュメントID
|
||||
|
||||
### Query
|
||||
- `metadata` (string) metadataのフィルター条件 `all`、`only`、または`without`。デフォルトは `all`。
|
||||
|
||||
### Response
|
||||
ナレッジベースドキュメントの詳細を返す.
|
||||
</Col>
|
||||
<Col sticky>
|
||||
### Request Example
|
||||
<CodeGroup title="Request" tag="GET" label="/datasets/{dataset_id}/documents/{document_id}" targetCode={`curl -X GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}' \\\n-H 'Authorization: Bearer {api_key}'`}>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}' \
|
||||
-H 'Authorization: Bearer {api_key}'
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
### Response Example
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"id": "f46ae30c-5c11-471b-96d0-464f5f32a7b2",
|
||||
"position": 1,
|
||||
"data_source_type": "upload_file",
|
||||
"data_source_info": {
|
||||
"upload_file": {
|
||||
...
|
||||
}
|
||||
},
|
||||
"dataset_process_rule_id": "24b99906-845e-499f-9e3c-d5565dd6962c",
|
||||
"dataset_process_rule": {
|
||||
"mode": "hierarchical",
|
||||
"rules": {
|
||||
"pre_processing_rules": [
|
||||
{
|
||||
"id": "remove_extra_spaces",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "remove_urls_emails",
|
||||
"enabled": false
|
||||
}
|
||||
],
|
||||
"segmentation": {
|
||||
"separator": "**********page_ending**********",
|
||||
"max_tokens": 1024,
|
||||
"chunk_overlap": 0
|
||||
},
|
||||
"parent_mode": "paragraph",
|
||||
"subchunk_segmentation": {
|
||||
"separator": "\n",
|
||||
"max_tokens": 512,
|
||||
"chunk_overlap": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"document_process_rule": {
|
||||
"id": "24b99906-845e-499f-9e3c-d5565dd6962c",
|
||||
"dataset_id": "48a0db76-d1a9-46c1-ae35-2baaa919a8a9",
|
||||
"mode": "hierarchical",
|
||||
"rules": {
|
||||
"pre_processing_rules": [
|
||||
{
|
||||
"id": "remove_extra_spaces",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "remove_urls_emails",
|
||||
"enabled": false
|
||||
}
|
||||
],
|
||||
"segmentation": {
|
||||
"separator": "**********page_ending**********",
|
||||
"max_tokens": 1024,
|
||||
"chunk_overlap": 0
|
||||
},
|
||||
"parent_mode": "paragraph",
|
||||
"subchunk_segmentation": {
|
||||
"separator": "\n",
|
||||
"max_tokens": 512,
|
||||
"chunk_overlap": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "xxxx",
|
||||
"created_from": "web",
|
||||
"created_by": "17f71940-a7b5-4c77-b60f-2bd645c1ffa0",
|
||||
"created_at": 1750464191,
|
||||
"tokens": null,
|
||||
"indexing_status": "waiting",
|
||||
"completed_at": null,
|
||||
"updated_at": 1750464191,
|
||||
"indexing_latency": null,
|
||||
"error": null,
|
||||
"enabled": true,
|
||||
"disabled_at": null,
|
||||
"disabled_by": null,
|
||||
"archived": false,
|
||||
"segment_count": 0,
|
||||
"average_segment_length": 0,
|
||||
"hit_count": null,
|
||||
"display_status": "queuing",
|
||||
"doc_form": "hierarchical_model",
|
||||
"doc_language": "Chinese Simplified"
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
___
|
||||
<hr className='ml-0 mr-0' />
|
||||
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/status/{action}'
|
||||
method='PATCH'
|
||||
|
|
|
|||
|
|
@ -1131,6 +1131,130 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
|
||||
<hr className='ml-0 mr-0' />
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}'
|
||||
method='GET'
|
||||
title='获取文档详情'
|
||||
name='#get-document-detail'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
获取文档详情.
|
||||
### Path
|
||||
- `dataset_id` (string) 知识库 ID
|
||||
- `document_id` (string) 文档 ID
|
||||
|
||||
### Query
|
||||
- `metadata` (string) metadata 过滤条件 `all`, `only`, 或者 `without`. 默认是 `all`.
|
||||
|
||||
### Response
|
||||
返回知识库文档的详情.
|
||||
</Col>
|
||||
<Col sticky>
|
||||
### Request Example
|
||||
<CodeGroup title="Request" tag="GET" label="/datasets/{dataset_id}/documents/{document_id}" targetCode={`curl -X GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}' \\\n-H 'Authorization: Bearer {api_key}'`}>
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}' \
|
||||
-H 'Authorization: Bearer {api_key}'
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
### Response Example
|
||||
<CodeGroup title="Response">
|
||||
```json {{ title: 'Response' }}
|
||||
{
|
||||
"id": "f46ae30c-5c11-471b-96d0-464f5f32a7b2",
|
||||
"position": 1,
|
||||
"data_source_type": "upload_file",
|
||||
"data_source_info": {
|
||||
"upload_file": {
|
||||
...
|
||||
}
|
||||
},
|
||||
"dataset_process_rule_id": "24b99906-845e-499f-9e3c-d5565dd6962c",
|
||||
"dataset_process_rule": {
|
||||
"mode": "hierarchical",
|
||||
"rules": {
|
||||
"pre_processing_rules": [
|
||||
{
|
||||
"id": "remove_extra_spaces",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "remove_urls_emails",
|
||||
"enabled": false
|
||||
}
|
||||
],
|
||||
"segmentation": {
|
||||
"separator": "**********page_ending**********",
|
||||
"max_tokens": 1024,
|
||||
"chunk_overlap": 0
|
||||
},
|
||||
"parent_mode": "paragraph",
|
||||
"subchunk_segmentation": {
|
||||
"separator": "\n",
|
||||
"max_tokens": 512,
|
||||
"chunk_overlap": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"document_process_rule": {
|
||||
"id": "24b99906-845e-499f-9e3c-d5565dd6962c",
|
||||
"dataset_id": "48a0db76-d1a9-46c1-ae35-2baaa919a8a9",
|
||||
"mode": "hierarchical",
|
||||
"rules": {
|
||||
"pre_processing_rules": [
|
||||
{
|
||||
"id": "remove_extra_spaces",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "remove_urls_emails",
|
||||
"enabled": false
|
||||
}
|
||||
],
|
||||
"segmentation": {
|
||||
"separator": "**********page_ending**********",
|
||||
"max_tokens": 1024,
|
||||
"chunk_overlap": 0
|
||||
},
|
||||
"parent_mode": "paragraph",
|
||||
"subchunk_segmentation": {
|
||||
"separator": "\n",
|
||||
"max_tokens": 512,
|
||||
"chunk_overlap": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"name": "xxxx",
|
||||
"created_from": "web",
|
||||
"created_by": "17f71940-a7b5-4c77-b60f-2bd645c1ffa0",
|
||||
"created_at": 1750464191,
|
||||
"tokens": null,
|
||||
"indexing_status": "waiting",
|
||||
"completed_at": null,
|
||||
"updated_at": 1750464191,
|
||||
"indexing_latency": null,
|
||||
"error": null,
|
||||
"enabled": true,
|
||||
"disabled_at": null,
|
||||
"disabled_by": null,
|
||||
"archived": false,
|
||||
"segment_count": 0,
|
||||
"average_segment_length": 0,
|
||||
"hit_count": null,
|
||||
"display_status": "queuing",
|
||||
"doc_form": "hierarchical_model",
|
||||
"doc_language": "Chinese Simplified"
|
||||
}
|
||||
```
|
||||
</CodeGroup>
|
||||
</Col>
|
||||
</Row>
|
||||
___
|
||||
<hr className='ml-0 mr-0' />
|
||||
|
||||
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/status/{action}'
|
||||
method='PATCH'
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import { useAppContext } from '@/context/app-context'
|
|||
import { fetchNotionConnection } from '@/service/common'
|
||||
import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
const Icon: FC<{
|
||||
src: string
|
||||
|
|
@ -33,6 +35,7 @@ const DataSourceNotion: FC<Props> = ({
|
|||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const [canConnectNotion, setCanConnectNotion] = useState(false)
|
||||
const { data } = useSWR(canConnectNotion ? '/oauth/data-source/notion' : null, fetchNotionConnection)
|
||||
const { t } = useTranslation()
|
||||
|
||||
const connected = !!workspaces.length
|
||||
|
||||
|
|
@ -51,9 +54,19 @@ const DataSourceNotion: FC<Props> = ({
|
|||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (data?.data)
|
||||
window.location.href = data.data
|
||||
}, [data])
|
||||
if (data && 'data' in data) {
|
||||
if (data.data && typeof data.data === 'string' && data.data.startsWith('http')) {
|
||||
window.location.href = data.data
|
||||
}
|
||||
else if (data.data === 'internal') {
|
||||
Toast.notify({
|
||||
type: 'info',
|
||||
message: t('common.dataSource.notion.integratedAlert'),
|
||||
})
|
||||
}
|
||||
}
|
||||
}, [data, t])
|
||||
|
||||
return (
|
||||
<Panel
|
||||
type={DataSourceType.notion}
|
||||
|
|
|
|||
|
|
@ -19,12 +19,14 @@ export enum FormTypeEnum {
|
|||
toolSelector = 'tool-selector',
|
||||
multiToolSelector = 'array[tools]',
|
||||
appSelector = 'app-selector',
|
||||
dynamicSelect = 'dynamic-select',
|
||||
}
|
||||
|
||||
export type FormOption = {
|
||||
label: TypeWithI18N
|
||||
value: string
|
||||
show_on: FormShowOnObject[]
|
||||
icon?: string
|
||||
}
|
||||
|
||||
export enum ModelTypeEnum {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ const HeaderWrapper = ({
|
|||
}: HeaderWrapperProps) => {
|
||||
const pathname = usePathname()
|
||||
const isBordered = ['/apps', '/datasets', '/datasets/create', '/tools'].includes(pathname)
|
||||
// // Check if the current path is a workflow canvas & fullscreen
|
||||
// Check if the current path is a workflow canvas & fullscreen
|
||||
const inWorkflowCanvas = pathname.endsWith('/workflow')
|
||||
const workflowCanvasMaximize = localStorage.getItem('workflow-canvas-maximize') === 'true'
|
||||
const [hideHeader, setHideHeader] = useState(workflowCanvasMaximize)
|
||||
|
|
@ -25,14 +25,12 @@ const HeaderWrapper = ({
|
|||
setHideHeader(v.payload)
|
||||
})
|
||||
|
||||
if (hideHeader && inWorkflowCanvas)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className={classNames(
|
||||
'sticky left-0 right-0 top-0 z-30 flex min-h-[56px] shrink-0 grow-0 basis-auto flex-col',
|
||||
'sticky left-0 right-0 top-0 z-[15] flex min-h-[56px] shrink-0 grow-0 basis-auto flex-col',
|
||||
s.header,
|
||||
isBordered ? 'border-b border-divider-regular' : '',
|
||||
hideHeader && inWorkflowCanvas && 'hidden',
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ const AppInputsPanel = ({
|
|||
return []
|
||||
let inputFormSchema = []
|
||||
if (isBasicApp) {
|
||||
inputFormSchema = currentApp.model_config.user_input_form.filter((item: any) => !item.external_data_tool).map((item: any) => {
|
||||
inputFormSchema = currentApp.model_config?.user_input_form?.filter((item: any) => !item.external_data_tool).map((item: any) => {
|
||||
if (item.paragraph) {
|
||||
return {
|
||||
...item.paragraph,
|
||||
|
|
@ -108,10 +108,10 @@ const AppInputsPanel = ({
|
|||
type: 'text-input',
|
||||
required: false,
|
||||
}
|
||||
})
|
||||
}) || []
|
||||
}
|
||||
else {
|
||||
const startNode = currentWorkflow?.graph.nodes.find(node => node.data.type === BlockEnum.Start) as any
|
||||
const startNode = currentWorkflow?.graph?.nodes.find(node => node.data.type === BlockEnum.Start) as any
|
||||
inputFormSchema = startNode?.data.variables.map((variable: any) => {
|
||||
if (variable.type === InputVarType.multiFiles) {
|
||||
return {
|
||||
|
|
@ -132,7 +132,7 @@ const AppInputsPanel = ({
|
|||
...variable,
|
||||
required: false,
|
||||
}
|
||||
})
|
||||
}) || []
|
||||
}
|
||||
if ((currentApp.mode === 'completion' || currentApp.mode === 'workflow') && basicAppFileConfig.enabled) {
|
||||
inputFormSchema.push({
|
||||
|
|
@ -144,7 +144,7 @@ const AppInputsPanel = ({
|
|||
fileUploadConfig,
|
||||
})
|
||||
}
|
||||
return inputFormSchema
|
||||
return inputFormSchema || []
|
||||
}, [basicAppFileConfig, currentApp, currentWorkflow, fileUploadConfig, isBasicApp])
|
||||
|
||||
const handleFormChange = (value: Record<string, any>) => {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,15 @@ type Props = {
|
|||
onSaved: (value: Record<string, any>) => void
|
||||
}
|
||||
|
||||
const extractDefaultValues = (schemas: any[]) => {
|
||||
const result: Record<string, any> = {}
|
||||
for (const field of schemas) {
|
||||
if (field.default !== undefined)
|
||||
result[field.name] = field.default
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
const EndpointModal: FC<Props> = ({
|
||||
formSchemas,
|
||||
defaultValues = {},
|
||||
|
|
@ -26,7 +35,10 @@ const EndpointModal: FC<Props> = ({
|
|||
}) => {
|
||||
const getValueFromI18nObject = useRenderI18nObject()
|
||||
const { t } = useTranslation()
|
||||
const [tempCredential, setTempCredential] = React.useState<any>(defaultValues)
|
||||
const initialValues = Object.keys(defaultValues).length > 0
|
||||
? defaultValues
|
||||
: extractDefaultValues(formSchemas)
|
||||
const [tempCredential, setTempCredential] = React.useState<any>(initialValues)
|
||||
|
||||
const handleSave = () => {
|
||||
for (const field of formSchemas) {
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ const MultipleToolSelector = ({
|
|||
)}
|
||||
{!disabled && (
|
||||
<ActionButton className='mx-1' onClick={() => {
|
||||
setCollapse(false)
|
||||
setOpen(!open)
|
||||
setPanelShowState(true)
|
||||
}}>
|
||||
|
|
@ -126,23 +127,6 @@ const MultipleToolSelector = ({
|
|||
</div>
|
||||
{!collapse && (
|
||||
<>
|
||||
<ToolSelector
|
||||
nodeId={nodeId}
|
||||
nodeOutputVars={nodeOutputVars}
|
||||
availableNodes={availableNodes}
|
||||
scope={scope}
|
||||
value={undefined}
|
||||
selectedTools={value}
|
||||
onSelect={handleAdd}
|
||||
controlledState={open}
|
||||
onControlledStateChange={setOpen}
|
||||
trigger={
|
||||
<div className=''></div>
|
||||
}
|
||||
panelShowState={panelShowState}
|
||||
onPanelShowStateChange={setPanelShowState}
|
||||
isEdit={false}
|
||||
/>
|
||||
{value.length === 0 && (
|
||||
<div className='system-xs-regular flex justify-center rounded-[10px] bg-background-section p-3 text-text-tertiary'>{t('plugin.detailPanel.toolSelector.empty')}</div>
|
||||
)}
|
||||
|
|
@ -164,6 +148,23 @@ const MultipleToolSelector = ({
|
|||
))}
|
||||
</>
|
||||
)}
|
||||
<ToolSelector
|
||||
nodeId={nodeId}
|
||||
nodeOutputVars={nodeOutputVars}
|
||||
availableNodes={availableNodes}
|
||||
scope={scope}
|
||||
value={undefined}
|
||||
selectedTools={value}
|
||||
onSelect={handleAdd}
|
||||
controlledState={open}
|
||||
onControlledStateChange={setOpen}
|
||||
trigger={
|
||||
<div className=''></div>
|
||||
}
|
||||
panelShowState={panelShowState}
|
||||
onPanelShowStateChange={setPanelShowState}
|
||||
isEdit={false}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ const ToolSelector: FC<Props> = ({
|
|||
/>
|
||||
)}
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-[1000]'>
|
||||
<PortalToFollowElemContent>
|
||||
<div className={cn('relative max-h-[642px] min-h-20 w-[361px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur pb-4 shadow-lg backdrop-blur-sm', !isShowSettingAuth && 'overflow-y-auto pb-2')}>
|
||||
{!isShowSettingAuth && (
|
||||
<>
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue