mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/r2
# Conflicts: # docker/docker-compose.middleware.yaml # web/app/components/workflow-app/components/workflow-main.tsx # web/app/components/workflow-app/hooks/index.ts # web/app/components/workflow/hooks-store/store.ts # web/app/components/workflow/hooks/index.ts # web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx
This commit is contained in:
commit
832bef053f
|
|
@ -47,15 +47,17 @@ jobs:
|
|||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
|
||||
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
|
|
|
|||
|
|
@ -214,3 +214,4 @@ mise.toml
|
|||
|
||||
# AI Assistant
|
||||
.roo/
|
||||
api/.env.backup
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
|
||||
|
||||
from libs.file_utils import search_file_upwards
|
||||
|
||||
from .deploy import DeploymentConfig
|
||||
from .enterprise import EnterpriseFeatureConfig
|
||||
|
|
@ -99,4 +102,12 @@ class DifyConfig(
|
|||
RemoteSettingsSourceFactory(settings_cls),
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
TomlConfigSettingsSource(
|
||||
settings_cls=settings_cls,
|
||||
toml_file=search_file_upwards(
|
||||
base_dir_path=Path(__file__).parent,
|
||||
target_file_name="pyproject.toml",
|
||||
max_search_parent_depth=2,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -223,6 +223,10 @@ class CeleryConfig(DatabaseConfig):
|
|||
default=None,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_PASSWORD: Optional[str] = Field(
|
||||
description="Password of the Redis Sentinel master.",
|
||||
default=None,
|
||||
)
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
|
||||
description="Timeout for Redis Sentinel socket operations in seconds.",
|
||||
default=0.1,
|
||||
|
|
|
|||
|
|
@ -1,17 +1,13 @@
|
|||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
|
||||
|
||||
|
||||
class PackagingInfo(BaseSettings):
|
||||
class PackagingInfo(PyProjectTomlConfig):
|
||||
"""
|
||||
Packaging build information
|
||||
"""
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="1.4.3",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
description="SHA-1 checksum of the git commit used to build the app",
|
||||
default="",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class PyProjectConfig(BaseModel):
|
||||
version: str = Field(description="Dify version", default="")
|
||||
|
||||
|
||||
class PyProjectTomlConfig(BaseSettings):
|
||||
"""
|
||||
configs in api/pyproject.toml
|
||||
"""
|
||||
|
||||
project: PyProjectConfig = Field(
|
||||
description="configs in the project section of pyproject.toml",
|
||||
default=PyProjectConfig(),
|
||||
)
|
||||
|
|
@ -41,7 +41,7 @@ class OAuthDataSource(Resource):
|
|||
if not internal_secret:
|
||||
return ({"error": "Internal secret is not set"},)
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return {"data": ""}
|
||||
return {"data": "internal"}
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return {"data": auth_url}, 200
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class VersionApi(Resource):
|
|||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
result = {
|
||||
"version": dify_config.CURRENT_VERSION,
|
||||
"version": dify_config.project.version,
|
||||
"release_date": "",
|
||||
"release_notes": "",
|
||||
"can_auto_update": False,
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ class MemberInviteEmailApi(Resource):
|
|||
return {
|
||||
"result": "success",
|
||||
"invitation_results": invitation_results,
|
||||
"tenant_id": str(current_user.current_tenant.id),
|
||||
}, 201
|
||||
|
||||
|
||||
|
|
@ -110,7 +111,7 @@ class MemberCancelInviteApi(Resource):
|
|||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
|
||||
|
||||
|
||||
class MemberUpdateRoleApi(Resource):
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from libs.login import login_required
|
||||
from models.account import TenantPluginPermission
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
|
@ -497,6 +498,42 @@ class PluginFetchPermissionApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
# check if the user is admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
parser.add_argument("provider", type=str, required=True, location="args")
|
||||
parser.add_argument("action", type=str, required=True, location="args")
|
||||
parser.add_argument("parameter", type=str, required=True, location="args")
|
||||
parser.add_argument("provider_type", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id,
|
||||
user_id,
|
||||
args["plugin_id"],
|
||||
args["provider"],
|
||||
args["action"],
|
||||
args["parameter"],
|
||||
args["provider_type"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
|
||||
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
|
||||
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
|
||||
|
|
@ -521,3 +558,5 @@ api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marke
|
|||
|
||||
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
|
||||
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
|
||||
|
||||
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from core.plugin.entities.request import (
|
|||
RequestInvokeApp,
|
||||
RequestInvokeEncrypt,
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeLLMWithStructuredOutput,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeParameterExtractorNode,
|
||||
RequestInvokeQuestionClassifierNode,
|
||||
|
|
@ -47,6 +48,21 @@ class PluginInvokeLLMApi(Resource):
|
|||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
class PluginInvokeLLMWithStructuredOutputApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
|
||||
user_model.id, tenant_model, payload
|
||||
)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
|
|
@ -291,6 +307,7 @@ class PluginFetchAppInfoApi(Resource):
|
|||
|
||||
|
||||
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
||||
api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output")
|
||||
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
||||
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
||||
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
|
||||
|
|
|
|||
|
|
@ -29,7 +29,19 @@ class EnterpriseWorkspace(Resource):
|
|||
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
return {"message": "enterprise workspace created."}
|
||||
resp = {
|
||||
"id": tenant.id,
|
||||
"name": tenant.name,
|
||||
"plan": tenant.plan,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
|
||||
"updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"message": "enterprise workspace created.",
|
||||
"tenant": resp,
|
||||
}
|
||||
|
||||
|
||||
class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
|
|
|
|||
|
|
@ -133,6 +133,22 @@ class DatasetListApi(DatasetApiResource):
|
|||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -265,10 +281,20 @@ class DatasetApi(DatasetApiResource):
|
|||
data = request.get_json()
|
||||
|
||||
# check embedding model setting
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
)
|
||||
if (
|
||||
data.get("retrieval_model")
|
||||
and data.get("retrieval_model").get("reranking_model")
|
||||
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
from sqlalchemy import desc, select
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
|
|
@ -18,6 +18,7 @@ from controllers.service_api.app.error import (
|
|||
from controllers.service_api.dataset.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentIndexingError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
|
|
@ -29,7 +30,7 @@ from extensions.ext_database import db
|
|||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.dataset_service import DocumentService
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
from services.file_service import FileService
|
||||
|
||||
|
|
@ -59,6 +60,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -74,6 +76,21 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
|
|
@ -124,6 +141,17 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
|
|
@ -188,6 +216,21 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
raise ValueError("indexing_technique is required.")
|
||||
args["indexing_technique"] = indexing_technique
|
||||
|
||||
if "embedding_model_provider" in args:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args["embedding_model_provider"], args["embedding_model"]
|
||||
)
|
||||
if (
|
||||
"retrieval_model" in args
|
||||
and args["retrieval_model"].get("reranking_model")
|
||||
and args["retrieval_model"].get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args["retrieval_model"].get("reranking_model").get("reranking_provider_name"),
|
||||
args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
# check file
|
||||
|
|
@ -424,6 +467,101 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
|||
return data
|
||||
|
||||
|
||||
class DocumentDetailApi(DatasetApiResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = self.get_dataset(dataset_id, tenant_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
if document.tenant_id != str(tenant_id):
|
||||
raise Forbidden("No permission.")
|
||||
|
||||
metadata = request.args.get("metadata", "all")
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
|
||||
|
||||
if metadata == "only":
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
"updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
|
||||
"indexing_latency": document.indexing_latency,
|
||||
"error": document.error,
|
||||
"enabled": document.enabled,
|
||||
"disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
|
||||
"disabled_by": document.disabled_by,
|
||||
"archived": document.archived,
|
||||
"segment_count": document.segment_count,
|
||||
"average_segment_length": document.average_segment_length,
|
||||
"hit_count": document.hit_count,
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
"updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
|
||||
"indexing_latency": document.indexing_latency,
|
||||
"error": document.error,
|
||||
"enabled": document.enabled,
|
||||
"disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
|
||||
"disabled_by": document.disabled_by,
|
||||
"archived": document.archived,
|
||||
"doc_type": document.doc_type,
|
||||
"doc_metadata": document.doc_metadata_details,
|
||||
"segment_count": document.segment_count,
|
||||
"average_segment_length": document.average_segment_length,
|
||||
"hit_count": document.hit_count,
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DocumentAddByTextApi,
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_text",
|
||||
|
|
@ -447,3 +585,4 @@ api.add_resource(
|
|||
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class IndexApi(Resource):
|
|||
return {
|
||||
"welcome": "Dify OpenAPI",
|
||||
"api_version": "v1",
|
||||
"server_version": dify_config.CURRENT_VERSION,
|
||||
"server_version": dify_config.project.version,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ from flask_restful import Resource
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -317,3 +317,11 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
|
|||
|
||||
class DatasetApiResource(Resource):
|
||||
method_decorators = [validate_dataset_token]
|
||||
|
||||
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
return dataset
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
|
|
@ -36,8 +39,10 @@ from libs.flask_utils import preserve_flask_contexts
|
|||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -451,6 +456,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
|
@ -480,8 +486,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner(
|
||||
|
|
@ -524,6 +528,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
|
@ -550,6 +555,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=stream,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, W
|
|||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
|
|
@ -94,6 +95,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
dialogue_count: int,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
) -> None:
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
|
@ -153,6 +155,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
self._conversation_name_generate_thread: Thread | None = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
|
@ -371,6 +374,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
session.commit()
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
|
@ -390,6 +394,8 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
if isinstance(event, QueueNodeExceptionEvent):
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
|
@ -759,3 +765,15 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
return message
|
||||
|
||||
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
|
||||
with Session(db.engine) as session, session.begin():
|
||||
saver = self._draft_var_saver_factory(
|
||||
session=session,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
|
||||
)
|
||||
saver.save(event.process_data, event.outputs)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ from factories import file_factory
|
|||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -238,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = AgentChatAppRunner()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,20 @@
|
|||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, final
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaver,
|
||||
DraftVariableSaverFactory,
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from factories import file_factory
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
|
|
@ -159,3 +169,38 @@ class BaseAppGenerator:
|
|||
yield f"event: {message}\n\n"
|
||||
|
||||
return gen()
|
||||
|
||||
@final
|
||||
@staticmethod
|
||||
def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
|
||||
def draft_var_saver_factory(
|
||||
session: Session,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
) -> DraftVariableSaver:
|
||||
return DraftVariableSaverImpl(
|
||||
session=session,
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
)
|
||||
else:
|
||||
|
||||
def draft_var_saver_factory(
|
||||
session: Session,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
) -> DraftVariableSaver:
|
||||
return NoopDraftVariableSaver()
|
||||
|
||||
return draft_var_saver_factory
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ from factories import file_factory
|
|||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -224,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
# chatbot app
|
||||
runner = ChatAppRunner()
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ from core.app.entities.task_entities import (
|
|||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
|
|
@ -516,7 +517,8 @@ class WorkflowResponseConverter:
|
|||
# Convert to tuple to match Sequence type
|
||||
return tuple(flattened_files)
|
||||
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
|
||||
@classmethod
|
||||
def _fetch_files_from_variable_value(cls, value: Union[dict, list, Segment]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
|
|
@ -525,20 +527,30 @@ class WorkflowResponseConverter:
|
|||
if not value:
|
||||
return []
|
||||
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
files: list[Mapping[str, Any]] = []
|
||||
if isinstance(value, FileSegment):
|
||||
files.append(value.value.to_dict())
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files.extend([i.to_dict() for i in value.value])
|
||||
elif isinstance(value, File):
|
||||
files.append(value.to_dict())
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
file = self._get_file_var_from_value(item)
|
||||
file = cls._get_file_var_from_value(item)
|
||||
if file:
|
||||
files.append(file)
|
||||
elif isinstance(value, dict):
|
||||
file = self._get_file_var_from_value(value)
|
||||
elif isinstance(
|
||||
value,
|
||||
dict,
|
||||
):
|
||||
file = cls._get_file_var_from_value(value)
|
||||
if file:
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
|
||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
|
|
|
|||
|
|
@ -201,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
try:
|
||||
# get message
|
||||
message = self._get_message(message_id)
|
||||
if message is None:
|
||||
raise MessageNotExistsError()
|
||||
|
||||
# chatbot app
|
||||
runner = CompletionAppRunner()
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from models.enums import CreatorUserRole
|
|||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
|
||||
return introduction or ""
|
||||
|
||||
def _get_conversation(self, conversation_id: str):
|
||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""
|
||||
Get conversation by conversation id
|
||||
:param conversation_id: conversation id
|
||||
|
|
@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
raise ConversationNotExistsError("Conversation not exists")
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Optional[Message]:
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
"""
|
||||
Get message by message id
|
||||
:param message_id: message id
|
||||
|
|
@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
"""
|
||||
message = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
return message
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
|||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
|
|
@ -219,6 +220,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.close()
|
||||
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
|
|
@ -233,6 +237,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
|
||||
worker_thread.start()
|
||||
|
||||
draft_var_saver_factory = self._get_draft_var_saver_factory(
|
||||
invoke_from,
|
||||
)
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
|
@ -241,6 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
|
|
@ -471,6 +480,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
|
@ -491,6 +501,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
|||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
|
|
@ -87,6 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
stream: bool,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
) -> None:
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
|
@ -131,6 +133,8 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._workflow_run_id = ""
|
||||
self._invoke_from = queue_manager._invoke_from
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
|
@ -322,6 +326,8 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(
|
||||
|
|
@ -339,6 +345,8 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
if isinstance(event, QueueNodeExceptionEvent):
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
|
|
@ -593,3 +601,15 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
)
|
||||
|
||||
return response
|
||||
|
||||
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
|
||||
with Session(db.engine) as session, session.begin():
|
||||
saver = self._draft_var_saver_factory(
|
||||
session=session,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
|
||||
)
|
||||
saver.save(event.process_data, event.outputs)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
|
|
@ -35,7 +33,6 @@ from core.workflow.entities.variable_pool import VariablePool
|
|||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
BaseNodeEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
|
|
@ -70,9 +67,6 @@ from core.workflow.workflow_entry import WorkflowEntry
|
|||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner(AppRunner):
|
||||
|
|
@ -400,7 +394,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
self._save_draft_var_for_event(event)
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
|
|
@ -464,7 +457,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
self._save_draft_var_for_event(event)
|
||||
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
|
|
@ -718,30 +710,3 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
|
||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
||||
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def _save_draft_var_for_event(self, event: BaseNodeEvent):
|
||||
run_result = event.route_node_state.node_run_result
|
||||
if run_result is None:
|
||||
return
|
||||
process_data = run_result.process_data
|
||||
outputs = run_result.outputs
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=self._get_app_id(),
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
# FIXME(QuantumGhost): rely on private state of queue_manager is not ideal.
|
||||
invoke_from=self.queue_manager._invoke_from,
|
||||
node_execution_id=event.id,
|
||||
enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
|
||||
)
|
||||
draft_var_saver.save(process_data=process_data, outputs=outputs)
|
||||
|
||||
|
||||
def _remove_first_element_from_variable_string(key: str) -> str:
|
||||
"""
|
||||
Remove the first element from the prefix.
|
||||
"""
|
||||
prefix, remaining = key.split(".", maxsplit=1)
|
||||
return remaining
|
||||
|
|
|
|||
|
|
@ -395,6 +395,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.llm_result.usage.latency = message.provider_response_latency
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,11 @@ class CommonParameterType(StrEnum):
|
|||
MODEL_SELECTOR = "model-selector"
|
||||
TOOLS_SELECTOR = "array[tools]"
|
||||
|
||||
# Dynamic select parameter
|
||||
# Once you are not sure about the available options until authorization is done
|
||||
# eg: Select a Slack channel from a Slack workspace
|
||||
DYNAMIC_SELECT = "dynamic-select"
|
||||
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -534,7 +534,7 @@ class IndexingRunner:
|
|||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
tokens = 0
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
||||
# create keyword index
|
||||
create_keyword_thread = threading.Thread(
|
||||
target=self._process_keyword_index,
|
||||
|
|
@ -572,7 +572,7 @@ class IndexingRunner:
|
|||
|
||||
for future in futures:
|
||||
tokens += future.result()
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
|
||||
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
|
||||
create_keyword_thread.join()
|
||||
indexing_end_at = time.perf_counter()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,374 @@
|
|||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, Optional, cast, overload
|
||||
|
||||
import json_repair
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[True] = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[False] = False,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Optional[Mapping] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
"""
|
||||
Invoke large language model with structured output
|
||||
1. This method invokes model_instance.invoke_llm with json_schema
|
||||
2. Try to parse the result as structured output
|
||||
|
||||
:param prompt_messages: prompt messages
|
||||
:param json_schema: json schema
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
# handle native json schema
|
||||
model_parameters_with_json_schema: dict[str, Any] = {
|
||||
**(model_parameters or {}),
|
||||
}
|
||||
|
||||
if model_schema.support_structure_output:
|
||||
model_parameters = _handle_native_json_schema(
|
||||
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
|
||||
)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
|
||||
|
||||
# handle prompt based schema
|
||||
prompt_messages = _handle_prompt_based_schema(
|
||||
prompt_messages=prompt_messages,
|
||||
structured_output_schema=json_schema,
|
||||
)
|
||||
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters_with_json_schema,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if isinstance(llm_result, LLMResult):
|
||||
if not isinstance(llm_result.message.content, str):
|
||||
raise OutputParserError(
|
||||
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(llm_result.message.content),
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
system_fingerprint=llm_result.system_fingerprint,
|
||||
prompt_messages=llm_result.prompt_messages,
|
||||
)
|
||||
else:
|
||||
|
||||
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
result_text: str = ""
|
||||
prompt_messages: Sequence[PromptMessage] = []
|
||||
system_fingerprint: Optional[str] = None
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(result_text),
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
finish_reason=None,
|
||||
),
|
||||
)
|
||||
|
||||
return generator()
|
||||
|
||||
|
||||
def _handle_native_json_schema(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
structured_output_schema: Mapping,
|
||||
model_parameters: dict,
|
||||
rules: list[ParameterRule],
|
||||
) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
:return: Updated model parameters with JSON schema configuration
|
||||
"""
|
||||
# Process schema according to model requirements
|
||||
schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema)
|
||||
|
||||
# Set JSON schema in parameters
|
||||
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
|
||||
return model_parameters
|
||||
|
||||
|
||||
def _set_response_format(model_parameters: dict, rules: list) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
|
||||
|
||||
def _handle_prompt_based_schema(
|
||||
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
This function modifies the prompt messages to include schema-based output requirements.
|
||||
|
||||
Args:
|
||||
prompt_messages: Original sequence of prompt messages
|
||||
|
||||
Returns:
|
||||
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||
"""
|
||||
# Convert schema to string format
|
||||
schema_str = json.dumps(structured_output_schema, ensure_ascii=False)
|
||||
|
||||
# Find existing system prompt with schema placeholder
|
||||
system_prompt = next(
|
||||
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||
None,
|
||||
)
|
||||
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||
# Prepare system prompt content
|
||||
system_prompt_content = (
|
||||
structured_output_prompt + "\n\n" + system_prompt.content
|
||||
if system_prompt and isinstance(system_prompt.content, str)
|
||||
else structured_output_prompt
|
||||
)
|
||||
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||
|
||||
# Extract content from the last user message
|
||||
|
||||
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||
updated_prompt = [system_prompt] + filtered_prompts
|
||||
|
||||
return updated_prompt
|
||||
|
||||
|
||||
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||
structured_output: Mapping[str, Any] = {}
|
||||
parsed: Mapping[str, Any] = {}
|
||||
try:
|
||||
parsed = TypeAdapter(Mapping).validate_json(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
except ValidationError:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
temp_parsed = json_repair.loads(result_text)
|
||||
if not isinstance(temp_parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(temp_parsed, list):
|
||||
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = cast(dict, temp_parsed)
|
||||
return structured_output
|
||||
|
||||
|
||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
|
||||
"""
|
||||
Prepare JSON schema based on model requirements.
|
||||
|
||||
Different models have different requirements for JSON schema formatting.
|
||||
This function handles these differences.
|
||||
|
||||
:param schema: The original JSON schema
|
||||
:return: Processed schema compatible with the current model
|
||||
"""
|
||||
|
||||
# Deep copy to avoid modifying the original schema
|
||||
processed_schema = dict(deepcopy(schema))
|
||||
|
||||
# Convert boolean types to string types (common requirement)
|
||||
convert_boolean_to_string(processed_schema)
|
||||
|
||||
# Apply model-specific transformations
|
||||
if SpecialModelType.GEMINI in model_schema.model:
|
||||
remove_additional_properties(processed_schema)
|
||||
return processed_schema
|
||||
elif SpecialModelType.OLLAMA in provider:
|
||||
return processed_schema
|
||||
else:
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict) -> None:
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Remove additionalProperties at current level
|
||||
schema.pop("additionalProperties", None)
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict) -> None:
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Check for boolean type at current level
|
||||
if schema.get("type") == "boolean":
|
||||
schema["type"] = "string"
|
||||
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
|
|
@ -291,3 +291,21 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
|
|||
|
||||
Now, generate a JSON Schema based on my description
|
||||
""" # noqa: E501
|
||||
|
||||
STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
|
||||
constraints:
|
||||
- You must output in JSON format.
|
||||
- Do not output boolean value, use string type instead.
|
||||
- Do not output integer or float value, use number type instead.
|
||||
eg:
|
||||
Here is the JSON schema:
|
||||
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
|
||||
|
||||
Here is the user's question:
|
||||
My name is John Doe and I am 30 years old.
|
||||
|
||||
output:
|
||||
{"name": "John Doe", "age": 30}
|
||||
Here is the JSON schema:
|
||||
{{schema}}
|
||||
""" # noqa: E501
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -101,6 +101,20 @@ class LLMResult(BaseModel):
|
|||
system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
class LLMStructuredOutput(BaseModel):
|
||||
"""
|
||||
Model class for llm structured output.
|
||||
"""
|
||||
|
||||
structured_output: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
|
||||
"""
|
||||
Model class for llm result with structured output.
|
||||
"""
|
||||
|
||||
|
||||
class LLMResultChunkDelta(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk delta.
|
||||
|
|
@ -123,6 +137,12 @@ class LLMResultChunk(BaseModel):
|
|||
delta: LLMResultChunkDelta
|
||||
|
||||
|
||||
class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
|
||||
"""
|
||||
Model class for llm result chunk with structured output.
|
||||
"""
|
||||
|
||||
|
||||
class NumTokensResult(PriceInfo):
|
||||
"""
|
||||
Model class for number of tokens result.
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
metadata=metadata,
|
||||
session_id=trace_info.conversation_id,
|
||||
tags=["message", "workflow"],
|
||||
version=trace_info.workflow_run_version,
|
||||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
workflow_span_data = LangfuseSpan(
|
||||
|
|
@ -108,6 +109,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
metadata=metadata,
|
||||
session_id=trace_info.conversation_id,
|
||||
tags=["workflow"],
|
||||
version=trace_info.workflow_run_version,
|
||||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
|
||||
|
|
@ -172,37 +174,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
}
|
||||
)
|
||||
|
||||
# add span
|
||||
if trace_info.message_id:
|
||||
span_data = LangfuseSpan(
|
||||
id=node_execution_id,
|
||||
name=node_type,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
trace_id=trace_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id,
|
||||
)
|
||||
else:
|
||||
span_data = LangfuseSpan(
|
||||
id=node_execution_id,
|
||||
name=node_type,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
trace_id=trace_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
# add generation span
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
total_token = metadata.get("total_tokens", 0)
|
||||
prompt_tokens = 0
|
||||
|
|
@ -226,10 +198,10 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
)
|
||||
|
||||
node_generation_data = LangfuseGeneration(
|
||||
name="llm",
|
||||
id=node_execution_id,
|
||||
name=node_name,
|
||||
trace_id=trace_id,
|
||||
model=process_data.get("model_name"),
|
||||
parent_observation_id=node_execution_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
input=inputs,
|
||||
|
|
@ -237,11 +209,30 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
self.add_generation(langfuse_generation_data=node_generation_data)
|
||||
|
||||
# add normal span
|
||||
else:
|
||||
span_data = LangfuseSpan(
|
||||
id=node_execution_id,
|
||||
name=node_name,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
trace_id=trace_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
|
||||
# get message file data
|
||||
file_list = trace_info.file_list
|
||||
|
|
@ -284,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
|
||||
# start add span
|
||||
# add generation
|
||||
generation_usage = GenerationUsage(
|
||||
input=trace_info.message_tokens,
|
||||
output=trace_info.answer_tokens,
|
||||
|
|
|
|||
|
|
@ -2,8 +2,15 @@ import tempfile
|
|||
from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
|
|
@ -12,6 +19,7 @@ from core.model_runtime.entities.message_entities import (
|
|||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeLLMWithStructuredOutput,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeRerank,
|
||||
RequestInvokeSpeech2Text,
|
||||
|
|
@ -81,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_llm_with_structured_output(
|
||||
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
|
||||
):
|
||||
"""
|
||||
invoke llm with structured output
|
||||
"""
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant.id,
|
||||
provider=payload.provider,
|
||||
model_type=payload.model_type,
|
||||
model=payload.model,
|
||||
)
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {payload.model}")
|
||||
|
||||
response = invoke_llm_with_structured_output(
|
||||
provider=payload.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=payload.prompt_messages,
|
||||
json_schema=payload.structured_output_schema,
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=True if payload.stream is None else payload.stream,
|
||||
user=user_id,
|
||||
model_parameters=payload.completion_params,
|
||||
)
|
||||
|
||||
if isinstance(response, Generator):
|
||||
|
||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=response.model,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
structured_output=response.structured_output,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ from core.tools.entities.common_entities import I18nObject
|
|||
class PluginParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
icon: Optional[str] = Field(
|
||||
default=None, description="The icon of the option, can be a url or a base64 encoded image"
|
||||
)
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -35,6 +38,7 @@ class PluginParameterType(enum.StrEnum):
|
|||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
|
@ -10,6 +10,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderEntit
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
|
|
@ -195,3 +196,7 @@ class PluginOAuthCredentialsResponse(BaseModel):
|
|||
class PluginListResponse(BaseModel):
|
||||
list: list[PluginEntity]
|
||||
total: int
|
||||
|
||||
|
||||
class PluginDynamicSelectOptionsResponse(BaseModel):
|
||||
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
||||
|
|
|
|||
|
|
@ -82,6 +82,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
|||
return v
|
||||
|
||||
|
||||
class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
|
||||
"""
|
||||
Request to invoke LLM with structured output
|
||||
"""
|
||||
|
||||
structured_output_schema: dict[str, Any] = Field(
|
||||
default_factory=dict, description="The schema of the structured output in JSON schema format"
|
||||
)
|
||||
|
||||
|
||||
class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
|
||||
"""
|
||||
Request to invoke text embedding
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class DynamicSelectClient(BasePluginClient):
|
||||
def fetch_dynamic_select_options(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
credentials: Mapping[str, Any],
|
||||
parameter: str,
|
||||
) -> PluginDynamicSelectOptionsResponse:
|
||||
"""
|
||||
Fetch dynamic select options for a plugin parameter.
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/dynamic_select/fetch_parameter_options",
|
||||
PluginDynamicSelectOptionsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": GenericProviderID(provider).provider_name,
|
||||
"credentials": credentials,
|
||||
"provider_action": action,
|
||||
"parameter": parameter,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for options in response:
|
||||
return options
|
||||
|
||||
raise ValueError(f"Plugin service returned no options for parameter '{parameter}' in provider '{provider}'")
|
||||
|
|
@ -79,6 +79,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
with_keywords = False
|
||||
if with_keywords:
|
||||
keywords_list = kwargs.get("keywords_list")
|
||||
keyword = Keyword(dataset)
|
||||
|
|
@ -94,6 +95,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
vector.delete_by_ids(node_ids)
|
||||
else:
|
||||
vector.delete()
|
||||
with_keywords = False
|
||||
if with_keywords:
|
||||
keyword = Keyword(dataset)
|
||||
if node_ids:
|
||||
|
|
|
|||
|
|
@ -1010,6 +1010,9 @@ class DatasetRetrieval:
|
|||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
|
||||
):
|
||||
if value is None:
|
||||
return
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
match condition:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
|
|
@ -25,6 +26,8 @@ class SimpleCode(BuiltinTool):
|
|||
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
|
||||
raise ValueError(f"Only python3 and javascript are supported, not {language}")
|
||||
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
|
||||
yield self.create_text_message(result)
|
||||
try:
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
yield self.create_text_message(result)
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
|
|
|
|||
|
|
@ -240,6 +240,7 @@ class ToolParameter(PluginParameter):
|
|||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ class ProviderConfigEncrypter(BaseModel):
|
|||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
|
|
|
|||
|
|
@ -67,11 +67,21 @@ class WorkflowNodeExecution(BaseModel):
|
|||
but they are not stored in the model.
|
||||
"""
|
||||
|
||||
# Core identification fields
|
||||
id: str # Unique identifier for this execution record
|
||||
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
|
||||
# --------- Core identification fields ---------
|
||||
|
||||
# Unique identifier for this execution record, used when persisting to storage.
|
||||
# Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
|
||||
id: str
|
||||
|
||||
# Optional secondary ID for cross-referencing purposes.
|
||||
#
|
||||
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
|
||||
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
|
||||
# In most scenarios, `id` should be used as the primary identifier.
|
||||
node_execution_id: Optional[str] = None
|
||||
workflow_id: str # ID of the workflow this node belongs to
|
||||
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
# --------- Core identification fields ends ---------
|
||||
|
||||
# Execution positioning and flow
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
|
|
|
|||
|
|
@ -158,7 +158,10 @@ class AgentNode(ToolNode):
|
|||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
|
|
@ -166,7 +169,8 @@ class AgentNode(ToolNode):
|
|||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
parameter_value = json.loads(parameter_value)
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import logging
|
|||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
|
|
@ -201,44 +200,3 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||
stream_out_answer_node_ids.append(answer_node_id)
|
||||
|
||||
return stream_out_answer_node_ids
|
||||
|
||||
@classmethod
|
||||
def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
file_var = cls._get_file_var_from_value(item)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
elif isinstance(value, dict):
|
||||
file_var = cls._get_file_var_from_value(value)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: dict | list):
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ class Executor:
|
|||
try:
|
||||
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e))
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response # type: ignore
|
||||
|
||||
|
|
|
|||
|
|
@ -490,6 +490,9 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
|
||||
):
|
||||
if value is None:
|
||||
return
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
match condition:
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ import logging
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import json_repair
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
|
|
@ -18,7 +18,13 @@ from core.model_runtime.entities import (
|
|||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
|
|
@ -31,7 +37,6 @@ from core.model_runtime.entities.model_entities import (
|
|||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -62,11 +67,6 @@ from core.workflow.nodes.event import (
|
|||
RunRetrieverResourceEvent,
|
||||
RunStreamChunkEvent,
|
||||
)
|
||||
from core.workflow.utils.structured_output.entities import (
|
||||
ResponseFormat,
|
||||
SpecialModelType,
|
||||
)
|
||||
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from . import llm_utils
|
||||
|
|
@ -143,12 +143,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
|
||||
"""Process structured output if enabled"""
|
||||
if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
|
||||
return None
|
||||
return self._parse_structured_output(text)
|
||||
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
result_text = ""
|
||||
|
|
@ -244,6 +238,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
stop=stop,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, RunStreamChunkEvent):
|
||||
yield event
|
||||
|
|
@ -254,10 +250,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
elif isinstance(event, LLMStructuredOutput):
|
||||
structured_output = event
|
||||
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
structured_output = process_structured_output(result_text)
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
if self._file_outputs is not None:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
|
|
@ -302,20 +300,40 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
node_data_model.name, model_instance.credentials
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {node_data_model.name}")
|
||||
|
||||
if self.node_data.structured_output_enabled:
|
||||
output_schema = self._fetch_structured_output_schema()
|
||||
invoke_result = invoke_llm_with_structured_output(
|
||||
provider=model_instance.provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=output_schema,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
)
|
||||
else:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
)
|
||||
|
||||
return self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
def _handle_invoke_result(
|
||||
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
|
||||
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
|
||||
# For blocking mode
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
event = self._handle_blocking_result(invoke_result=invoke_result)
|
||||
|
|
@ -329,23 +347,32 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
full_text_buffer = io.StringIO()
|
||||
for result in invoke_result:
|
||||
contents = result.delta.message.content
|
||||
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
|
||||
full_text_buffer.write(text_part)
|
||||
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
|
||||
# Consume the invoke result and handle generator exception
|
||||
try:
|
||||
for result in invoke_result:
|
||||
if isinstance(result, LLMResultChunkWithStructuredOutput):
|
||||
yield result
|
||||
if isinstance(result, LLMResultChunk):
|
||||
contents = result.delta.message.content
|
||||
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
|
||||
full_text_buffer.write(text_part)
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
model = result.model
|
||||
if len(prompt_messages) == 0:
|
||||
# TODO(QuantumGhost): it seems that this update has no visable effect.
|
||||
# What's the purpose of the line below?
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if usage.prompt_tokens == 0 and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
if finish_reason is None and result.delta.finish_reason:
|
||||
finish_reason = result.delta.finish_reason
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
model = result.model
|
||||
if len(prompt_messages) == 0:
|
||||
# TODO(QuantumGhost): it seems that this update has no visable effect.
|
||||
# What's the purpose of the line below?
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if usage.prompt_tokens == 0 and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
if finish_reason is None and result.delta.finish_reason:
|
||||
finish_reason = result.delta.finish_reason
|
||||
except OutputParserError as e:
|
||||
raise LLMNodeError(f"Failed to parse structured output: {e}")
|
||||
|
||||
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
|
||||
|
||||
|
|
@ -522,12 +549,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
if self.node_data.structured_output_enabled:
|
||||
if model_schema.support_structure_output:
|
||||
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
self._set_response_format(completion_params, model_schema.parameter_rules)
|
||||
model_config_with_cred.parameters = completion_params
|
||||
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
|
||||
node_data_model.completion_params = completion_params
|
||||
|
|
@ -719,32 +740,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_config.model} not exist.")
|
||||
if self.node_data.structured_output_enabled:
|
||||
if not model_schema.support_structure_output:
|
||||
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||
prompt_messages=filtered_prompt_messages,
|
||||
)
|
||||
return filtered_prompt_messages, model_config.stop
|
||||
|
||||
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
|
||||
structured_output: dict[str, Any] = {}
|
||||
try:
|
||||
parsed = json.loads(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
except json.JSONDecodeError as e:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
parsed = json_repair.loads(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(parsed, list):
|
||||
parsed = next((item for item in parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
return structured_output
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
|
@ -934,104 +931,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
self._file_outputs.append(saved_file)
|
||||
return saved_file
|
||||
|
||||
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
:return: Updated model parameters with JSON schema configuration
|
||||
"""
|
||||
# Process schema according to model requirements
|
||||
schema = self._fetch_structured_output_schema()
|
||||
schema_json = self._prepare_schema_for_model(schema)
|
||||
|
||||
# Set JSON schema in parameters
|
||||
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
|
||||
return model_parameters
|
||||
|
||||
def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
This function modifies the prompt messages to include schema-based output requirements.
|
||||
|
||||
Args:
|
||||
prompt_messages: Original sequence of prompt messages
|
||||
|
||||
Returns:
|
||||
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||
"""
|
||||
# Convert schema to string format
|
||||
schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
|
||||
|
||||
# Find existing system prompt with schema placeholder
|
||||
system_prompt = next(
|
||||
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||
None,
|
||||
)
|
||||
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||
# Prepare system prompt content
|
||||
system_prompt_content = (
|
||||
structured_output_prompt + "\n\n" + system_prompt.content
|
||||
if system_prompt and isinstance(system_prompt.content, str)
|
||||
else structured_output_prompt
|
||||
)
|
||||
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||
|
||||
# Extract content from the last user message
|
||||
|
||||
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||
updated_prompt = [system_prompt] + filtered_prompts
|
||||
|
||||
return updated_prompt
|
||||
|
||||
def _set_response_format(self, model_parameters: dict, rules: list) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
:param model_parameters: Model parameters to update
|
||||
:param rules: Model parameter rules
|
||||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
|
||||
def _prepare_schema_for_model(self, schema: dict) -> dict:
|
||||
"""
|
||||
Prepare JSON schema based on model requirements.
|
||||
|
||||
Different models have different requirements for JSON schema formatting.
|
||||
This function handles these differences.
|
||||
|
||||
:param schema: The original JSON schema
|
||||
:return: Processed schema compatible with the current model
|
||||
"""
|
||||
|
||||
# Deep copy to avoid modifying the original schema
|
||||
processed_schema = schema.copy()
|
||||
|
||||
# Convert boolean types to string types (common requirement)
|
||||
convert_boolean_to_string(processed_schema)
|
||||
|
||||
# Apply model-specific transformations
|
||||
if SpecialModelType.GEMINI in self.node_data.model.name:
|
||||
remove_additional_properties(processed_schema)
|
||||
return processed_schema
|
||||
elif SpecialModelType.OLLAMA in self.node_data.model.provider:
|
||||
return processed_schema
|
||||
else:
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
|
||||
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
|
||||
"""
|
||||
Fetch model schema
|
||||
|
|
@ -1243,49 +1142,3 @@ def _handle_completion_template(
|
|||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict) -> None:
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Remove additionalProperties at current level
|
||||
schema.pop("additionalProperties", None)
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict) -> None:
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Check for boolean type at current level
|
||||
if schema.get("type") == "boolean":
|
||||
schema["type"] = "string"
|
||||
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
|
|
|
|||
|
|
@ -167,7 +167,9 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class DraftVariableSaver(Protocol):
|
||||
@abc.abstractmethod
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
|
||||
pass
|
||||
|
||||
|
||||
class DraftVariableSaverFactory(Protocol):
|
||||
@abc.abstractmethod
|
||||
def __call__(
|
||||
self,
|
||||
session: Session,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
) -> "DraftVariableSaver":
|
||||
pass
|
||||
|
||||
|
||||
class NoopDraftVariableSaver(DraftVariableSaver):
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
|
||||
pass
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from enum import StrEnum
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
|
||||
constraints:
|
||||
- You must output in JSON format.
|
||||
- Do not output boolean value, use string type instead.
|
||||
- Do not output integer or float value, use number type instead.
|
||||
eg:
|
||||
Here is the JSON schema:
|
||||
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
|
||||
|
||||
Here is the user's question:
|
||||
My name is John Doe and I am 30 years old.
|
||||
|
||||
output:
|
||||
{"name": "John Doe", "age": 30}
|
||||
Here is the JSON schema:
|
||||
{{schema}}
|
||||
""" # noqa: E501
|
||||
|
|
@ -7,6 +7,7 @@ def append_variables_recursively(
|
|||
):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param pool: variable pool to append variables to
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from core.workflow.enums import SystemVariableKey
|
|||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -160,12 +161,13 @@ class WorkflowCycleManager:
|
|||
exceptions_count: int = 0,
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
now = naive_utc_now()
|
||||
|
||||
workflow_execution.status = WorkflowExecutionStatus(status.value)
|
||||
workflow_execution.error_message = error_message
|
||||
workflow_execution.total_tokens = total_tokens
|
||||
workflow_execution.total_steps = total_steps
|
||||
workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_execution.finished_at = now
|
||||
workflow_execution.exceptions_count = exceptions_count
|
||||
|
||||
# Use the instance repository to find running executions for a workflow run
|
||||
|
|
@ -174,7 +176,6 @@ class WorkflowCycleManager:
|
|||
)
|
||||
|
||||
# Update the domain models
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
for node_execution in running_node_executions:
|
||||
if node_execution.node_execution_id:
|
||||
# Update the domain model
|
||||
|
|
|
|||
|
|
@ -300,7 +300,7 @@ class WorkflowEntry:
|
|||
return node_instance, generator
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
|
||||
"error while running node_instance, node_id=%s, type=%s, version=%s",
|
||||
node_instance.id,
|
||||
node_instance.node_type,
|
||||
node_instance.version(),
|
||||
|
|
|
|||
|
|
@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
|
|||
from .create_document_index import handle
|
||||
from .create_installed_app_when_app_created import handle
|
||||
from .create_site_record_when_app_created import handle
|
||||
from .deduct_quota_when_message_created import handle
|
||||
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
|
||||
from .update_app_dataset_join_when_app_model_config_updated import handle
|
||||
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
||||
from .update_provider_last_used_at_when_message_created import handle
|
||||
|
||||
# Consolidated handler replaces both deduct_quota_when_message_created and
|
||||
# update_provider_last_used_at_when_message_created
|
||||
from .update_provider_when_message_created import handle
|
||||
|
|
|
|||
|
|
@ -1,65 +0,0 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender, **kwargs):
|
||||
message = sender
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
model_config = application_generate_entity.model_conf
|
||||
provider_model_bundle = model_config.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
if not system_configuration.current_quota_type:
|
||||
return
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = message.message_tokens + message.answer_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_config.model)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
).update(
|
||||
{
|
||||
"quota_used": Provider.quota_used + used_quota,
|
||||
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.provider import Provider
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender, **kwargs):
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
Provider.provider_name == application_generate_entity.model_conf.provider,
|
||||
).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
|
||||
db.session.commit()
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
import logging
|
||||
import time as time_module
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs import datetime_utils
|
||||
from models.model import Message
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _ProviderUpdateFilters(BaseModel):
|
||||
"""Filters for identifying Provider records to update."""
|
||||
|
||||
tenant_id: str
|
||||
provider_name: str
|
||||
provider_type: Optional[str] = None
|
||||
quota_type: Optional[str] = None
|
||||
|
||||
|
||||
class _ProviderUpdateAdditionalFilters(BaseModel):
|
||||
"""Additional filters for Provider updates."""
|
||||
|
||||
quota_limit_check: bool = False
|
||||
|
||||
|
||||
class _ProviderUpdateValues(BaseModel):
|
||||
"""Values to update in Provider records."""
|
||||
|
||||
last_used: Optional[datetime] = None
|
||||
quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
|
||||
|
||||
|
||||
class _ProviderUpdateOperation(BaseModel):
|
||||
"""A single Provider update operation."""
|
||||
|
||||
filters: _ProviderUpdateFilters
|
||||
values: _ProviderUpdateValues
|
||||
additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
|
||||
description: str = "unknown"
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender: Message, **kwargs):
|
||||
"""
|
||||
Consolidated handler for Provider updates when a message is created.
|
||||
|
||||
This handler replaces both:
|
||||
- update_provider_last_used_at_when_message_created
|
||||
- deduct_quota_when_message_created
|
||||
|
||||
By performing all Provider updates in a single transaction, we ensure
|
||||
consistency and efficiency when updating Provider records.
|
||||
"""
|
||||
message = sender
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
tenant_id = application_generate_entity.app_config.tenant_id
|
||||
provider_name = application_generate_entity.model_conf.provider
|
||||
current_time = datetime_utils.naive_utc_now()
|
||||
|
||||
# Prepare updates for both scenarios
|
||||
updates_to_perform: list[_ProviderUpdateOperation] = []
|
||||
|
||||
# 1. Always update last_used for the provider
|
||||
basic_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
),
|
||||
values=_ProviderUpdateValues(last_used=current_time),
|
||||
description="basic_last_used_update",
|
||||
)
|
||||
updates_to_perform.append(basic_update)
|
||||
|
||||
# 2. Check if we need to deduct quota (system provider only)
|
||||
model_config = application_generate_entity.model_conf
|
||||
provider_model_bundle = model_config.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if (
|
||||
provider_configuration.using_provider_type == ProviderType.SYSTEM
|
||||
and provider_configuration.system_configuration
|
||||
and provider_configuration.system_configuration.current_quota_type is not None
|
||||
):
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
# Calculate quota usage
|
||||
used_quota = _calculate_quota_usage(
|
||||
message=message,
|
||||
system_configuration=system_configuration,
|
||||
model_name=model_config.model,
|
||||
)
|
||||
|
||||
if used_quota is not None:
|
||||
quota_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
||||
),
|
||||
description="quota_deduction_update",
|
||||
)
|
||||
updates_to_perform.append(quota_update)
|
||||
|
||||
# Execute all updates
|
||||
start_time = time_module.perf_counter()
|
||||
try:
|
||||
_execute_provider_updates(updates_to_perform)
|
||||
|
||||
# Log successful completion with timing
|
||||
duration = time_module.perf_counter() - start_time
|
||||
|
||||
logger.info(
|
||||
f"Provider updates completed successfully. "
|
||||
f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
|
||||
f"Tenant: {tenant_id}, Provider: {provider_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Log failure with timing and context
|
||||
duration = time_module.perf_counter() - start_time
|
||||
|
||||
logger.exception(
|
||||
f"Provider updates failed after {duration:.3f}s. "
|
||||
f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
|
||||
f"Provider: {provider_name}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _calculate_quota_usage(
|
||||
*, message: Message, system_configuration: SystemConfiguration, model_name: str
|
||||
) -> Optional[int]:
|
||||
"""Calculate quota usage based on message tokens and quota type."""
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return None
|
||||
break
|
||||
if quota_unit is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
tokens = message.message_tokens + message.answer_tokens
|
||||
return tokens
|
||||
if quota_unit == QuotaUnit.CREDITS:
|
||||
tokens = dify_config.get_model_credits(model_name)
|
||||
return tokens
|
||||
elif quota_unit == QuotaUnit.TIMES:
|
||||
return 1
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to calculate quota usage")
|
||||
return None
|
||||
|
||||
|
||||
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
|
||||
"""Execute all Provider updates in a single transaction."""
|
||||
if not updates_to_perform:
|
||||
return
|
||||
|
||||
# Use SQLAlchemy's context manager for transaction management
|
||||
# This automatically handles commit/rollback
|
||||
with Session(db.engine) as session:
|
||||
# Use a single transaction for all updates
|
||||
for update_operation in updates_to_perform:
|
||||
filters = update_operation.filters
|
||||
values = update_operation.values
|
||||
additional_filters = update_operation.additional_filters
|
||||
description = update_operation.description
|
||||
|
||||
# Build the where conditions
|
||||
where_conditions = [
|
||||
Provider.tenant_id == filters.tenant_id,
|
||||
Provider.provider_name == filters.provider_name,
|
||||
]
|
||||
|
||||
# Add additional filters if specified
|
||||
if filters.provider_type is not None:
|
||||
where_conditions.append(Provider.provider_type == filters.provider_type)
|
||||
if filters.quota_type is not None:
|
||||
where_conditions.append(Provider.quota_type == filters.quota_type)
|
||||
if additional_filters.quota_limit_check:
|
||||
where_conditions.append(Provider.quota_limit > Provider.quota_used)
|
||||
|
||||
# Prepare values dict for SQLAlchemy update
|
||||
update_values = {}
|
||||
if values.last_used is not None:
|
||||
update_values["last_used"] = values.last_used
|
||||
if values.quota_used is not None:
|
||||
update_values["quota_used"] = values.quota_used
|
||||
|
||||
# Build and execute the update statement
|
||||
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
||||
result = session.execute(stmt)
|
||||
rows_affected = result.rowcount
|
||||
|
||||
logger.debug(
|
||||
f"Provider update ({description}): {rows_affected} rows affected. "
|
||||
f"Filters: {filters.model_dump()}, Values: {update_values}"
|
||||
)
|
||||
|
||||
# If no rows were affected for quota updates, log a warning
|
||||
if rows_affected == 0 and description == "quota_deduction_update":
|
||||
logger.warning(
|
||||
f"No Provider rows updated for quota deduction. "
|
||||
f"This may indicate quota limit exceeded or provider not found. "
|
||||
f"Filters: {filters.model_dump()}"
|
||||
)
|
||||
|
||||
logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")
|
||||
|
|
@ -12,14 +12,14 @@ def init_app(app: DifyApp):
|
|||
@app.after_request
|
||||
def after_request(response):
|
||||
"""Add Version headers to the response."""
|
||||
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
|
||||
response.headers.add("X-Version", dify_config.project.version)
|
||||
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
|
||||
return response
|
||||
|
||||
@app.route("/health")
|
||||
def health():
|
||||
return Response(
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.project.version}),
|
||||
status=200,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ def init_app(app: DifyApp) -> Celery:
|
|||
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
"sentinel_kwargs": {
|
||||
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
"password": dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def init_app(app: DifyApp):
|
|||
logging.getLogger().addHandler(exception_handler)
|
||||
|
||||
def init_flask_instrumentor(app: DifyApp):
|
||||
meter = get_meter("http_metrics", version=dify_config.CURRENT_VERSION)
|
||||
meter = get_meter("http_metrics", version=dify_config.project.version)
|
||||
_http_response_counter = meter.create_counter(
|
||||
"http.server.response.count",
|
||||
description="Total number of HTTP responses by status code, method and target",
|
||||
|
|
@ -163,7 +163,7 @@ def init_app(app: DifyApp):
|
|||
resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: dify_config.APPLICATION_NAME,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.PROCESS_PID: os.getpid(),
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
|
|
|
|||
|
|
@ -35,6 +35,6 @@ def init_app(app: DifyApp):
|
|||
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
|
||||
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
|
||||
environment=dify_config.DEPLOY_ENV,
|
||||
release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
|
||||
release=f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
before_send=before_send,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -432,7 +432,7 @@ def get_file_type_by_mime_type(mime_type: str) -> FileType:
|
|||
|
||||
class StorageKeyLoader:
|
||||
"""FileKeyLoader load the storage key from database for a list of files.
|
||||
This loader is batched, the
|
||||
This loader is batched, the database query count is constant regardless of the input size.
|
||||
"""
|
||||
|
||||
def __init__(self, session: Session, tenant_id: str) -> None:
|
||||
|
|
@ -493,10 +493,10 @@ class StorageKeyLoader:
|
|||
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
|
||||
upload_file_row = upload_files.get(model_id)
|
||||
if upload_file_row is None:
|
||||
raise ValueError(...)
|
||||
raise ValueError(f"Upload file not found for id: {model_id}")
|
||||
file._storage_key = upload_file_row.key
|
||||
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file_row = tool_files.get(model_id)
|
||||
if tool_file_row is None:
|
||||
raise ValueError(...)
|
||||
raise ValueError(f"Tool file not found for id: {model_id}")
|
||||
file._storage_key = tool_file_row.file_key
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
def search_file_upwards(
|
||||
base_dir_path: Path,
|
||||
target_file_name: str,
|
||||
max_search_parent_depth: int,
|
||||
) -> Path:
|
||||
"""
|
||||
Find a target file in the current directory or its parent directories up to a specified depth.
|
||||
:param base_dir_path: Starting directory path to search from.
|
||||
:param target_file_name: Name of the file to search for.
|
||||
:param max_search_parent_depth: Maximum number of parent directories to search upwards.
|
||||
:return: Path of the file if found, otherwise None.
|
||||
"""
|
||||
current_path = base_dir_path.resolve()
|
||||
for _ in range(max_search_parent_depth):
|
||||
candidate_path = current_path / target_file_name
|
||||
if candidate_path.is_file():
|
||||
return candidate_path
|
||||
parent_path = current_path.parent
|
||||
if parent_path == current_path: # reached the root directory
|
||||
break
|
||||
else:
|
||||
current_path = parent_path
|
||||
|
||||
raise ValueError(
|
||||
f"File '{target_file_name}' not found in the directory '{base_dir_path.resolve()}' or its parent directories"
|
||||
f" in depth of {max_search_parent_depth}."
|
||||
)
|
||||
|
|
@ -162,7 +162,7 @@ class Dataset(Base):
|
|||
def word_count(self):
|
||||
return (
|
||||
db.session.query(Document)
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count)))
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
|
||||
.filter(Document.dataset_id == self.id)
|
||||
.scalar()
|
||||
)
|
||||
|
|
@ -480,7 +480,7 @@ class Document(Base):
|
|||
def hit_count(self):
|
||||
return (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
|
||||
.filter(DocumentSegment.document_id == self.id)
|
||||
.scalar()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -677,7 +677,7 @@ class Conversation(Base):
|
|||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
value["tool_file_id"] = value["related_id"]
|
||||
elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
value["upload_file_id"] = value["related_id"]
|
||||
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
|
||||
elif isinstance(value, list) and all(
|
||||
|
|
@ -687,7 +687,7 @@ class Conversation(Base):
|
|||
for item in value:
|
||||
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
item["tool_file_id"] = item["related_id"]
|
||||
elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
item["upload_file_id"] = item["related_id"]
|
||||
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
|
||||
|
||||
|
|
@ -719,7 +719,6 @@ class Conversation(Base):
|
|||
if "model" in override_model_configs:
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
|
||||
assert app_model_config is not None, "app model config not found"
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs
|
||||
|
|
@ -915,11 +914,11 @@ class Message(Base):
|
|||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
query: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||
message = db.Column(db.JSON, nullable=False)
|
||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
answer: Mapped[str] = db.Column(db.Text, nullable=False)
|
||||
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
parent_message_id = db.Column(StringUUID, nullable=True)
|
||||
|
|
@ -948,7 +947,7 @@ class Message(Base):
|
|||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
value["tool_file_id"] = value["related_id"]
|
||||
elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
value["upload_file_id"] = value["related_id"]
|
||||
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
|
||||
elif isinstance(value, list) and all(
|
||||
|
|
@ -958,7 +957,7 @@ class Message(Base):
|
|||
for item in value:
|
||||
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||
item["tool_file_id"] = item["related_id"]
|
||||
elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
|
||||
item["upload_file_id"] = item["related_id"]
|
||||
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
|
||||
return inputs
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "dify-api"
|
||||
dynamic = ["version"]
|
||||
version = "1.5.1"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
|
@ -155,6 +155,7 @@ dev = [
|
|||
"types_setuptools>=80.9.0",
|
||||
"pandas-stubs~=2.2.3",
|
||||
"scipy-stubs>=1.15.3.0",
|
||||
"types-python-http-client>=3.3.7.20240910",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
|
@ -197,7 +198,7 @@ vdb = [
|
|||
"pymochow==1.3.1",
|
||||
"pyobvector~=0.1.6",
|
||||
"qdrant-client==1.9.0",
|
||||
"tablestore==6.1.0",
|
||||
"tablestore==6.2.0",
|
||||
"tcvectordb~=1.6.4",
|
||||
"tidb-vector==0.0.9",
|
||||
"upstash-vector==0.6.0",
|
||||
|
|
|
|||
|
|
@ -889,7 +889,7 @@ class RegisterService:
|
|||
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
|
||||
|
||||
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
|
||||
dify_setup = DifySetup(version=dify_config.project.version)
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -323,6 +323,23 @@ class DatasetService:
|
|||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
|
||||
@staticmethod
|
||||
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=reranking_model_provider,
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model,
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
|
||||
@staticmethod
|
||||
def update_dataset(dataset_id, data, user):
|
||||
"""
|
||||
|
|
@ -645,6 +662,10 @@ class DatasetService:
|
|||
)
|
||||
except ProviderTokenNotInitError:
|
||||
# If we can't get the embedding model, preserve existing settings
|
||||
logging.warning(
|
||||
f"Failed to initialize embedding model {data['embedding_model_provider']}/{data['embedding_model']}, "
|
||||
f"preserving existing settings"
|
||||
)
|
||||
if dataset.embedding_model_provider and dataset.embedding_model:
|
||||
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
|
||||
filtered_data["embedding_model"] = dataset.embedding_model
|
||||
|
|
@ -2661,6 +2682,7 @@ class SegmentService:
|
|||
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
segment.answer = args.answer
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
|
||||
else:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
|
||||
|
|
|
|||
|
|
@ -1,23 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
class ModerationService:
|
||||
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
|
||||
app_model_config: Optional[AppModelConfig] = None
|
||||
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
|
||||
if not app_model_config:
|
||||
raise ValueError("app model config not found")
|
||||
|
||||
name = app_model_config.sensitive_word_avoidance_dict["type"]
|
||||
config = app_model_config.sensitive_word_avoidance_dict["config"]
|
||||
|
||||
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
|
||||
return moderation.moderation_for_outputs(text)
|
||||
|
|
@ -8,9 +8,10 @@ from extensions.ext_redis import redis_client
|
|||
class OAuthProxyService(BasePluginClient):
|
||||
# Default max age for proxy context parameter in seconds
|
||||
__MAX_AGE__ = 5 * 60 # 5 minutes
|
||||
__KEY_PREFIX__ = "oauth_proxy_context:"
|
||||
|
||||
@staticmethod
|
||||
def create_proxy_context(user_id, tenant_id, plugin_id, provider):
|
||||
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
|
||||
"""
|
||||
Create a proxy context for an OAuth 2.0 authorization request.
|
||||
|
||||
|
|
@ -23,26 +24,22 @@ class OAuthProxyService(BasePluginClient):
|
|||
is used to verify the state, ensuring the request's integrity and authenticity,
|
||||
and mitigating replay attacks.
|
||||
"""
|
||||
seconds, _ = redis_client.time()
|
||||
context_id = str(uuid.uuid4())
|
||||
data = {
|
||||
"user_id": user_id,
|
||||
"plugin_id": plugin_id,
|
||||
"tenant_id": tenant_id,
|
||||
"provider": provider,
|
||||
# encode redis time to avoid distribution time skew
|
||||
"timestamp": seconds,
|
||||
}
|
||||
# ignore nonce collision
|
||||
redis_client.setex(
|
||||
f"oauth_proxy_context:{context_id}",
|
||||
f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
|
||||
OAuthProxyService.__MAX_AGE__,
|
||||
json.dumps(data),
|
||||
)
|
||||
return context_id
|
||||
|
||||
@staticmethod
|
||||
def use_proxy_context(context_id, max_age=__MAX_AGE__):
|
||||
def use_proxy_context(context_id: str):
|
||||
"""
|
||||
Validate the proxy context parameter.
|
||||
This checks if the context_id is valid and not expired.
|
||||
|
|
@ -50,12 +47,7 @@ class OAuthProxyService(BasePluginClient):
|
|||
if not context_id:
|
||||
raise ValueError("context_id is required")
|
||||
# get data from redis
|
||||
data = redis_client.getdel(f"oauth_proxy_context:{context_id}")
|
||||
data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}")
|
||||
if not data:
|
||||
raise ValueError("context_id is invalid")
|
||||
# check if data is expired
|
||||
seconds, _ = redis_client.time()
|
||||
state = json.loads(data)
|
||||
if state.get("timestamp") < seconds - max_age:
|
||||
raise ValueError("context_id is expired")
|
||||
return state
|
||||
return json.loads(data)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import BuiltinToolProvider
|
||||
|
||||
|
||||
class PluginParameterService:
|
||||
@staticmethod
|
||||
def get_dynamic_select_options(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
parameter: str,
|
||||
provider_type: Literal["tool"],
|
||||
) -> Sequence[PluginParameterOption]:
|
||||
"""
|
||||
Get dynamic select options for a plugin parameter.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID.
|
||||
plugin_id: The plugin ID.
|
||||
provider: The provider name.
|
||||
action: The action name.
|
||||
parameter: The parameter name.
|
||||
"""
|
||||
credentials: Mapping[str, Any] = {}
|
||||
|
||||
match provider_type:
|
||||
case "tool":
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# check if credentials are required
|
||||
if not provider_controller.need_credentials:
|
||||
credentials = {}
|
||||
else:
|
||||
# fetch credentials from db
|
||||
with Session(db.engine) as session:
|
||||
db_record = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_record is None:
|
||||
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
||||
|
||||
credentials = tool_configuration.decrypt(db_record.credentials)
|
||||
case _:
|
||||
raise ValueError(f"Invalid provider type: {provider_type}")
|
||||
|
||||
return (
|
||||
DynamicSelectClient()
|
||||
.fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter)
|
||||
.options
|
||||
)
|
||||
|
|
@ -97,16 +97,16 @@ class VectorService:
|
|||
vector = Vector(dataset=dataset)
|
||||
vector.delete_by_ids([segment.index_node_id])
|
||||
vector.add_texts([document], duplicate_check=True)
|
||||
|
||||
# update keyword index
|
||||
keyword = Keyword(dataset)
|
||||
keyword.delete_by_ids([segment.index_node_id])
|
||||
|
||||
# save keyword index
|
||||
if keywords and len(keywords) > 0:
|
||||
keyword.add_texts([document], keywords_list=[keywords])
|
||||
else:
|
||||
keyword.add_texts([document])
|
||||
# update keyword index
|
||||
keyword = Keyword(dataset)
|
||||
keyword.delete_by_ids([segment.index_node_id])
|
||||
|
||||
# save keyword index
|
||||
if keywords and len(keywords) > 0:
|
||||
keyword.add_texts([document], keywords_list=[keywords])
|
||||
else:
|
||||
keyword.add_texts([document])
|
||||
|
||||
@classmethod
|
||||
def generate_child_chunks(
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class WorkflowDraftVariableService:
|
|||
variables = (
|
||||
# Do not load the `value` field.
|
||||
query.options(orm.defer(WorkflowDraftVariable.value))
|
||||
.order_by(WorkflowDraftVariable.id.desc())
|
||||
.order_by(WorkflowDraftVariable.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset((page - 1) * limit)
|
||||
.all()
|
||||
|
|
@ -168,7 +168,7 @@ class WorkflowDraftVariableService:
|
|||
WorkflowDraftVariable.node_id == node_id,
|
||||
)
|
||||
query = self._session.query(WorkflowDraftVariable).filter(*criteria)
|
||||
variables = query.order_by(WorkflowDraftVariable.id.desc()).all()
|
||||
variables = query.order_by(WorkflowDraftVariable.created_at.desc()).all()
|
||||
return WorkflowDraftVariableList(variables=variables)
|
||||
|
||||
def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
|
||||
|
|
@ -235,7 +235,9 @@ class WorkflowDraftVariableService:
|
|||
self._session.flush()
|
||||
return variable
|
||||
|
||||
def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
|
||||
def _reset_node_var_or_sys_var(
|
||||
self, workflow: Workflow, variable: WorkflowDraftVariable
|
||||
) -> WorkflowDraftVariable | None:
|
||||
# If a variable does not allow updating, it makes no sence to resetting it.
|
||||
if not variable.editable:
|
||||
return variable
|
||||
|
|
@ -259,28 +261,35 @@ class WorkflowDraftVariableService:
|
|||
self._session.flush()
|
||||
return None
|
||||
|
||||
# Get node type for proper value extraction
|
||||
node_config = workflow.get_node_config_by_id(variable.node_id)
|
||||
node_type = workflow.get_node_type_from_node_config(node_config)
|
||||
|
||||
outputs_dict = node_exec.outputs_dict or {}
|
||||
# a sentinel value used to check the absent of the output variable key.
|
||||
absent = object()
|
||||
|
||||
# Note: Based on the implementation in `_build_from_variable_assigner_mapping`,
|
||||
# VariableAssignerNode (both v1 and v2) can only create conversation draft variables.
|
||||
# For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes.
|
||||
#
|
||||
# This implementation must remain synchronized with the `_build_from_variable_assigner_mapping`
|
||||
# and `save` methods.
|
||||
if node_type == NodeType.VARIABLE_ASSIGNER:
|
||||
return variable
|
||||
if variable.get_variable_type() == DraftVariableType.NODE:
|
||||
# Get node type for proper value extraction
|
||||
node_config = workflow.get_node_config_by_id(variable.node_id)
|
||||
node_type = workflow.get_node_type_from_node_config(node_config)
|
||||
|
||||
if variable.name not in outputs_dict:
|
||||
# Note: Based on the implementation in `_build_from_variable_assigner_mapping`,
|
||||
# VariableAssignerNode (both v1 and v2) can only create conversation draft variables.
|
||||
# For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes.
|
||||
#
|
||||
# This implementation must remain synchronized with the `_build_from_variable_assigner_mapping`
|
||||
# and `save` methods.
|
||||
if node_type == NodeType.VARIABLE_ASSIGNER:
|
||||
return variable
|
||||
output_value = outputs_dict.get(variable.name, absent)
|
||||
else:
|
||||
output_value = outputs_dict.get(f"sys.{variable.name}", absent)
|
||||
|
||||
# We cannot use `is None` to check the existence of an output variable here as
|
||||
# the value of the output may be `None`.
|
||||
if output_value is absent:
|
||||
# If variable not found in execution data, delete the variable
|
||||
self._session.delete(instance=variable)
|
||||
self._session.flush()
|
||||
return None
|
||||
value = outputs_dict[variable.name]
|
||||
value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, value)
|
||||
value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, output_value)
|
||||
# Extract variable value using unified logic
|
||||
variable.set_value(value_seg)
|
||||
variable.last_edited_at = None # Reset to indicate this is a reset operation
|
||||
|
|
@ -291,10 +300,8 @@ class WorkflowDraftVariableService:
|
|||
variable_type = variable.get_variable_type()
|
||||
if variable_type == DraftVariableType.CONVERSATION:
|
||||
return self._reset_conv_var(workflow, variable)
|
||||
elif variable_type == DraftVariableType.NODE:
|
||||
return self._reset_node_var(workflow, variable)
|
||||
else:
|
||||
raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}")
|
||||
return self._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
def delete_variable(self, variable: WorkflowDraftVariable):
|
||||
self._session.delete(variable)
|
||||
|
|
@ -439,6 +446,9 @@ def _batch_upsert_draft_varaible(
|
|||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
|
||||
set_={
|
||||
# Refresh creation timestamp to ensure updated variables
|
||||
# appear first in chronologically sorted result sets.
|
||||
"created_at": stmt.excluded.created_at,
|
||||
"updated_at": stmt.excluded.updated_at,
|
||||
"last_edited_at": stmt.excluded.last_edited_at,
|
||||
"description": stmt.excluded.description,
|
||||
|
|
@ -525,9 +535,6 @@ class DraftVariableSaver:
|
|||
# The type of the current node (see NodeType).
|
||||
_node_type: NodeType
|
||||
|
||||
# Indicates how the workflow execution was triggered (see InvokeFrom).
|
||||
_invoke_from: InvokeFrom
|
||||
|
||||
#
|
||||
_node_execution_id: str
|
||||
|
||||
|
|
@ -546,15 +553,16 @@ class DraftVariableSaver:
|
|||
app_id: str,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
invoke_from: InvokeFrom,
|
||||
node_execution_id: str,
|
||||
enclosing_node_id: str | None = None,
|
||||
):
|
||||
# Important: `node_execution_id` parameter refers to the primary key (`id`) of the
|
||||
# WorkflowNodeExecutionModel/WorkflowNodeExecution, not their `node_execution_id`
|
||||
# field. These are distinct database fields with different purposes.
|
||||
self._session = session
|
||||
self._app_id = app_id
|
||||
self._node_id = node_id
|
||||
self._node_type = node_type
|
||||
self._invoke_from = invoke_from
|
||||
self._node_execution_id = node_execution_id
|
||||
self._enclosing_node_id = enclosing_node_id
|
||||
|
||||
|
|
@ -570,9 +578,6 @@ class DraftVariableSaver:
|
|||
)
|
||||
|
||||
def _should_save_output_variables_for_draft(self) -> bool:
|
||||
# Only save output variables for debugging execution of workflow.
|
||||
if self._invoke_from != InvokeFrom.DEBUGGER:
|
||||
return False
|
||||
if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER:
|
||||
# Currently we do not save output variables for nodes inside loop or iteration.
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
|
|||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables import Variable
|
||||
|
|
@ -413,7 +412,6 @@ class WorkflowService:
|
|||
app_id=app_model.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=NodeType(workflow_node_execution.node_type),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
node_execution_id=node_execution.id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.output_parser.structured_output import _parse_structured_output
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
|
@ -277,29 +278,6 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
|
|||
|
||||
|
||||
def test_extract_json():
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_config": {
|
||||
"structured_output": {
|
||||
"enabled": True,
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "number"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
"prompt_template": [{"role": "user", "text": "{{#sys.query#}}"}],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
},
|
||||
)
|
||||
llm_texts = [
|
||||
'<think>\n\n</think>{"name": "test", "age": 123', # resoning model (deepseek-r1)
|
||||
'{"name":"test","age":123}', # json schema model (gpt-4o)
|
||||
|
|
@ -308,4 +286,4 @@ def test_extract_json():
|
|||
'{"name":"test",age:123}', # without quotes (qwen-2.5-0.5b)
|
||||
]
|
||||
result = {"name": "test", "age": 123}
|
||||
assert all(node._parse_structured_output(item) == result for item in llm_texts)
|
||||
assert all(_parse_structured_output(item) == result for item in llm_texts)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,259 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterFetchFilesFromVariableValue:
|
||||
"""Test class for WorkflowResponseConverter._fetch_files_from_variable_value method"""
|
||||
|
||||
def create_test_file(self, file_id: str = "test_file_1") -> File:
|
||||
"""Create a test File object"""
|
||||
return File(
|
||||
id=file_id,
|
||||
tenant_id="test_tenant",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related_123",
|
||||
filename=f"{file_id}.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="storage_key_123",
|
||||
)
|
||||
|
||||
def create_file_dict(self, file_id: str = "test_file_dict") -> dict:
|
||||
"""Create a file dictionary with correct dify_model_identity"""
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"id": file_id,
|
||||
"tenant_id": "test_tenant",
|
||||
"type": "document",
|
||||
"transfer_method": "local_file",
|
||||
"related_id": "related_456",
|
||||
"filename": f"{file_id}.txt",
|
||||
"extension": ".txt",
|
||||
"mime_type": "text/plain",
|
||||
"size": 2048,
|
||||
"url": "http://example.com/file.txt",
|
||||
}
|
||||
|
||||
def test_fetch_files_from_variable_value_with_none(self):
|
||||
"""Test with None input"""
|
||||
# The method signature expects Union[dict, list, Segment], but implementation handles None
|
||||
# We'll test the actual behavior by passing an empty dict instead
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_empty_dict(self):
|
||||
"""Test with empty dictionary"""
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value({})
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_empty_list(self):
|
||||
"""Test with empty list"""
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([])
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_file_segment(self):
|
||||
"""Test with valid FileSegment"""
|
||||
test_file = self.create_test_file("segment_file")
|
||||
file_segment = FileSegment(value=test_file)
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], dict)
|
||||
assert result[0]["id"] == "segment_file"
|
||||
assert result[0]["dify_model_identity"] == FILE_MODEL_IDENTITY
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_single(self):
|
||||
"""Test with ArrayFileSegment containing single file"""
|
||||
test_file = self.create_test_file("array_file_1")
|
||||
array_segment = ArrayFileSegment(value=[test_file])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], dict)
|
||||
assert result[0]["id"] == "array_file_1"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_multiple(self):
|
||||
"""Test with ArrayFileSegment containing multiple files"""
|
||||
test_file_1 = self.create_test_file("array_file_1")
|
||||
test_file_2 = self.create_test_file("array_file_2")
|
||||
array_segment = ArrayFileSegment(value=[test_file_1, test_file_2])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "array_file_1"
|
||||
assert result[1]["id"] == "array_file_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_empty(self):
|
||||
"""Test with ArrayFileSegment containing empty array"""
|
||||
array_segment = ArrayFileSegment(value=[])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_of_file_dicts(self):
|
||||
"""Test with list containing file dictionaries"""
|
||||
file_dict_1 = self.create_file_dict("list_file_1")
|
||||
file_dict_2 = self.create_file_dict("list_file_2")
|
||||
test_list = [file_dict_1, file_dict_2]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "list_file_1"
|
||||
assert result[1]["id"] == "list_file_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_of_file_objects(self):
|
||||
"""Test with list containing File objects"""
|
||||
file_obj_1 = self.create_test_file("list_obj_1")
|
||||
file_obj_2 = self.create_test_file("list_obj_2")
|
||||
test_list = [file_obj_1, file_obj_2]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "list_obj_1"
|
||||
assert result[1]["id"] == "list_obj_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_mixed_valid_invalid(self):
|
||||
"""Test with list containing mix of valid files and invalid items"""
|
||||
file_dict = self.create_file_dict("mixed_file")
|
||||
invalid_dict = {"not_a_file": "value"}
|
||||
test_list = [file_dict, invalid_dict, "string_item", 123]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == "mixed_file"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_nested_structures(self):
|
||||
"""Test with list containing nested structures"""
|
||||
file_dict = self.create_file_dict("nested_file")
|
||||
nested_list = [file_dict, ["inner_list"]]
|
||||
test_list = [nested_list, {"nested": "dict"}]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
# Should not process nested structures in list items
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_incorrect_identity(self):
|
||||
"""Test with dictionary having incorrect dify_model_identity"""
|
||||
invalid_dict = {"dify_model_identity": "wrong_identity", "id": "invalid_file", "filename": "test.txt"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_missing_identity(self):
|
||||
"""Test with dictionary missing dify_model_identity"""
|
||||
invalid_dict = {"id": "no_identity_file", "filename": "test.txt"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_file_object(self):
|
||||
"""Test with dictionary containing File object"""
|
||||
file_obj = self.create_test_file("dict_obj_file")
|
||||
test_dict = {"file_key": file_obj}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_dict)
|
||||
|
||||
# Should not extract File objects from dict values
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_mixed_data_types(self):
|
||||
"""Test with various mixed data types"""
|
||||
mixed_data = {"string": "text", "number": 42, "boolean": True, "null": None, "dify_model_identity": "wrong"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(mixed_data)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_invalid_objects(self):
|
||||
"""Test with invalid objects that are not supported types"""
|
||||
# Test with an invalid dict that doesn't match expected patterns
|
||||
invalid_dict = {"custom_key": "custom_value"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_string_input(self):
|
||||
"""Test with string input (unsupported type)"""
|
||||
# Since method expects Union[dict, list, Segment], test with empty list instead
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_number_input(self):
|
||||
"""Test with number input (unsupported type)"""
|
||||
# Test with list containing numbers (should be ignored)
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([42, "string", None])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_return_type_is_sequence(self):
|
||||
"""Test that return type is Sequence[Mapping[str, Any]]"""
|
||||
file_dict = self.create_file_dict("type_test_file")
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_dict)
|
||||
|
||||
assert isinstance(result, Sequence)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], Mapping)
|
||||
assert all(isinstance(key, str) for key in result[0])
|
||||
|
||||
def test_fetch_files_from_variable_value_preserves_file_properties(self):
|
||||
"""Test that all file properties are preserved in the result"""
|
||||
original_file = self.create_test_file("property_test")
|
||||
file_segment = FileSegment(value=original_file)
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
file_dict = result[0]
|
||||
assert file_dict["id"] == "property_test"
|
||||
assert file_dict["tenant_id"] == "test_tenant"
|
||||
assert file_dict["type"] == "document"
|
||||
assert file_dict["transfer_method"] == "local_file"
|
||||
assert file_dict["filename"] == "property_test.txt"
|
||||
assert file_dict["extension"] == ".txt"
|
||||
assert file_dict["mime_type"] == "text/plain"
|
||||
assert file_dict["size"] == 1024
|
||||
|
||||
def test_fetch_files_from_variable_value_with_complex_nested_scenario(self):
|
||||
"""Test complex scenario with nested valid and invalid data"""
|
||||
file_dict = self.create_file_dict("complex_file")
|
||||
file_obj = self.create_test_file("complex_obj")
|
||||
|
||||
# Complex nested structure
|
||||
complex_data = [
|
||||
file_dict, # Valid file dict
|
||||
file_obj, # Valid file object
|
||||
{ # Invalid dict
|
||||
"not_file": "data",
|
||||
"nested": {"deep": "value"},
|
||||
},
|
||||
[ # Nested list (should be ignored)
|
||||
self.create_file_dict("nested_file")
|
||||
],
|
||||
"string", # Invalid string
|
||||
None, # None value
|
||||
42, # Invalid number
|
||||
]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(complex_data)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "complex_file"
|
||||
assert result[1]["id"] == "complex_obj"
|
||||
|
|
@ -8,151 +8,298 @@ from services.dataset_service import DatasetService
|
|||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetPermissionTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset permission tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
tenant_id: str = "test-tenant-123",
|
||||
created_by: str = "creator-456",
|
||||
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.created_by = created_by
|
||||
dataset.permission = permission
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(
|
||||
user_id: str = "user-789",
|
||||
tenant_id: str = "test-tenant-123",
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock user with specified attributes."""
|
||||
user = Mock(spec=Account)
|
||||
user.id = user_id
|
||||
user.current_tenant_id = tenant_id
|
||||
user.current_role = role
|
||||
for key, value in kwargs.items():
|
||||
setattr(user, key, value)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
account_id: str = "user-789",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset permission record."""
|
||||
permission = Mock(spec=DatasetPermission)
|
||||
permission.dataset_id = dataset_id
|
||||
permission.account_id = account_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(permission, key, value)
|
||||
return permission
|
||||
|
||||
|
||||
class TestDatasetPermissionService:
|
||||
"""Test cases for dataset permission checking functionality"""
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.check_dataset_permission method.
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures"""
|
||||
# Mock tenant and user
|
||||
self.tenant_id = "test-tenant-123"
|
||||
self.creator_id = "creator-456"
|
||||
self.normal_user_id = "normal-789"
|
||||
self.owner_user_id = "owner-999"
|
||||
This test suite covers all permission scenarios including:
|
||||
- Cross-tenant access restrictions
|
||||
- Owner privilege checks
|
||||
- Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
|
||||
- Explicit permission checks for PARTIAL_TEAM
|
||||
- Error conditions and logging
|
||||
"""
|
||||
|
||||
# Mock dataset
|
||||
self.dataset = Mock(spec=Dataset)
|
||||
self.dataset.id = "dataset-123"
|
||||
self.dataset.tenant_id = self.tenant_id
|
||||
self.dataset.created_by = self.creator_id
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with patch("services.dataset_service.db.session") as mock_session:
|
||||
yield {
|
||||
"db_session": mock_session,
|
||||
}
|
||||
|
||||
# Mock users
|
||||
self.creator_user = Mock(spec=Account)
|
||||
self.creator_user.id = self.creator_id
|
||||
self.creator_user.current_tenant_id = self.tenant_id
|
||||
self.creator_user.current_role = TenantAccountRole.EDITOR
|
||||
|
||||
self.normal_user = Mock(spec=Account)
|
||||
self.normal_user.id = self.normal_user_id
|
||||
self.normal_user.current_tenant_id = self.tenant_id
|
||||
self.normal_user.current_role = TenantAccountRole.NORMAL
|
||||
|
||||
self.owner_user = Mock(spec=Account)
|
||||
self.owner_user.id = self.owner_user_id
|
||||
self.owner_user.current_tenant_id = self.tenant_id
|
||||
self.owner_user.current_role = TenantAccountRole.OWNER
|
||||
|
||||
def test_permission_check_different_tenant_should_fail(self):
|
||||
"""Test that users from different tenants cannot access dataset"""
|
||||
self.normal_user.current_tenant_id = "different-tenant"
|
||||
|
||||
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
|
||||
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
|
||||
|
||||
def test_owner_can_access_any_dataset(self):
|
||||
"""Test that tenant owners can access any dataset regardless of permission"""
|
||||
self.dataset.permission = DatasetPermissionEnum.ONLY_ME
|
||||
@pytest.fixture
|
||||
def mock_logging_dependencies(self):
|
||||
"""Mock setup for logging tests."""
|
||||
with patch("services.dataset_service.logging") as mock_logging:
|
||||
yield {
|
||||
"logging": mock_logging,
|
||||
}
|
||||
|
||||
def _assert_permission_check_passes(self, dataset: Mock, user: Mock):
|
||||
"""Helper method to verify that permission check passes without raising exceptions."""
|
||||
# Should not raise any exception
|
||||
DatasetService.check_dataset_permission(self.dataset, self.owner_user)
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
def test_only_me_permission_creator_can_access(self):
|
||||
"""Test ONLY_ME permission allows only creator to access"""
|
||||
self.dataset.permission = DatasetPermissionEnum.ONLY_ME
|
||||
def _assert_permission_check_fails(
|
||||
self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset."
|
||||
):
|
||||
"""Helper method to verify that permission check fails with expected error."""
|
||||
with pytest.raises(NoPermissionError, match=expected_message):
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
# Creator should be able to access
|
||||
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
|
||||
def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str):
|
||||
"""Helper method to verify database query calls for permission checks."""
|
||||
mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id)
|
||||
|
||||
def test_only_me_permission_others_cannot_access(self):
|
||||
"""Test ONLY_ME permission denies access to non-creators"""
|
||||
self.dataset.permission = DatasetPermissionEnum.ONLY_ME
|
||||
|
||||
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
|
||||
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
|
||||
|
||||
def test_all_team_permission_allows_access(self):
|
||||
"""Test ALL_TEAM permission allows any team member to access"""
|
||||
self.dataset.permission = DatasetPermissionEnum.ALL_TEAM
|
||||
|
||||
# Should not raise any exception for team members
|
||||
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
|
||||
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
|
||||
|
||||
@patch("services.dataset_service.db.session")
|
||||
def test_partial_team_permission_creator_can_access(self, mock_session):
|
||||
"""Test PARTIAL_TEAM permission allows creator to access"""
|
||||
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
|
||||
|
||||
# Should not raise any exception for creator
|
||||
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
|
||||
|
||||
# Should not query database for creator
|
||||
def _assert_database_query_not_called(self, mock_session: Mock):
|
||||
"""Helper method to verify that database query was not called."""
|
||||
mock_session.query.assert_not_called()
|
||||
|
||||
@patch("services.dataset_service.db.session")
|
||||
def test_partial_team_permission_with_explicit_permission(self, mock_session):
|
||||
"""Test PARTIAL_TEAM permission allows users with explicit permission"""
|
||||
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
|
||||
# ==================== Cross-Tenant Access Tests ====================
|
||||
|
||||
def test_permission_check_different_tenant_should_fail(self):
|
||||
"""Test that users from different tenants cannot access dataset regardless of other permissions."""
|
||||
# Create dataset and user from different tenants
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM
|
||||
)
|
||||
user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Should fail due to different tenant
|
||||
self._assert_permission_check_fails(dataset, user)
|
||||
|
||||
# ==================== Owner Privilege Tests ====================
|
||||
|
||||
def test_owner_can_access_any_dataset(self):
|
||||
"""Test that tenant owners can access any dataset regardless of permission level."""
|
||||
# Create dataset with restrictive permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
|
||||
|
||||
# Create owner user
|
||||
owner_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="owner-999", role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
# Owner should have access regardless of dataset permission
|
||||
self._assert_permission_check_passes(dataset, owner_user)
|
||||
|
||||
# ==================== ONLY_ME Permission Tests ====================
|
||||
|
||||
def test_only_me_permission_creator_can_access(self):
|
||||
"""Test ONLY_ME permission allows only the dataset creator to access."""
|
||||
# Create dataset with ONLY_ME permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should be able to access
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
|
||||
def test_only_me_permission_others_cannot_access(self):
|
||||
"""Test ONLY_ME permission denies access to non-creators."""
|
||||
# Create dataset with ONLY_ME permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Non-creator should be denied access
|
||||
self._assert_permission_check_fails(dataset, normal_user)
|
||||
|
||||
# ==================== ALL_TEAM Permission Tests ====================
|
||||
|
||||
def test_all_team_permission_allows_access(self):
|
||||
"""Test ALL_TEAM permission allows any team member to access the dataset."""
|
||||
# Create dataset with ALL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM)
|
||||
|
||||
# Create different types of team members
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
editor_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="editor-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# All team members should have access
|
||||
self._assert_permission_check_passes(dataset, normal_user)
|
||||
self._assert_permission_check_passes(dataset, editor_user)
|
||||
|
||||
# ==================== PARTIAL_TEAM Permission Tests ====================
|
||||
|
||||
def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission allows creator to access without database query."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should have access without database query
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"])
|
||||
|
||||
def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission allows users with explicit permission records."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return a permission record
|
||||
mock_permission = Mock(spec=DatasetPermission)
|
||||
mock_session.query().filter_by().first.return_value = mock_permission
|
||||
mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset.id, account_id=normal_user.id
|
||||
)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission
|
||||
|
||||
# Should not raise any exception
|
||||
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
|
||||
# User with explicit permission should have access
|
||||
self._assert_permission_check_passes(dataset, normal_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
|
||||
|
||||
# Verify database was queried correctly
|
||||
mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id)
|
||||
def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission denies users without explicit permission records."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
@patch("services.dataset_service.db.session")
|
||||
def test_partial_team_permission_without_explicit_permission(self, mock_session):
|
||||
"""Test PARTIAL_TEAM permission denies users without explicit permission"""
|
||||
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_session.query().filter_by().first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
|
||||
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
|
||||
# User without explicit permission should be denied access
|
||||
self._assert_permission_check_fails(dataset, normal_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
|
||||
|
||||
# Verify database was queried correctly
|
||||
mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id)
|
||||
|
||||
@patch("services.dataset_service.db.session")
|
||||
def test_partial_team_permission_non_creator_without_permission_fails(self, mock_session):
|
||||
"""Test that non-creators without explicit permission are denied access"""
|
||||
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
|
||||
def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies):
|
||||
"""Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create a different user (not the creator)
|
||||
other_user = Mock(spec=Account)
|
||||
other_user.id = "other-user-123"
|
||||
other_user.current_tenant_id = self.tenant_id
|
||||
other_user.current_role = TenantAccountRole.NORMAL
|
||||
other_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="other-user-123", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_session.query().filter_by().first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
|
||||
DatasetService.check_dataset_permission(self.dataset, other_user)
|
||||
# Non-creator without explicit permission should be denied access
|
||||
self._assert_permission_check_fails(dataset, other_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id)
|
||||
|
||||
# ==================== Enum Usage Tests ====================
|
||||
|
||||
def test_partial_team_permission_uses_correct_enum(self):
|
||||
"""Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM"""
|
||||
# This test ensures we're using the enum instead of string literals
|
||||
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
|
||||
|
||||
# Creator should always have access
|
||||
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
|
||||
|
||||
@patch("services.dataset_service.logging")
|
||||
@patch("services.dataset_service.db.session")
|
||||
def test_permission_denied_logs_debug_message(self, mock_session, mock_logging):
|
||||
"""Test that permission denied events are logged"""
|
||||
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM
|
||||
mock_session.query().filter_by().first.return_value = None
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
|
||||
|
||||
# Verify debug message was logged
|
||||
mock_logging.debug.assert_called_with(
|
||||
f"User {self.normal_user.id} does not have permission to access dataset {self.dataset.id}"
|
||||
"""Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals."""
|
||||
# Create dataset with PARTIAL_TEAM permission using enum
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should always have access regardless of permission level
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
|
||||
# ==================== Logging Tests ====================
|
||||
|
||||
def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies):
|
||||
"""Test that permission denied events are properly logged for debugging purposes."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
# Attempt permission check (should fail)
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.check_dataset_permission(dataset, normal_user)
|
||||
|
||||
# Verify debug message was logged with correct user and dataset information
|
||||
mock_logging_dependencies["logging"].debug.assert_called_with(
|
||||
f"User {normal_user.id} does not have permission to access dataset {dataset.id}"
|
||||
)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -6,12 +6,11 @@ from unittest.mock import Mock, patch
|
|||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables import StringSegment
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes import NodeType
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
VariableResetError,
|
||||
|
|
@ -32,7 +31,6 @@ class TestDraftVariableSaver:
|
|||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
node_type=NodeType.START,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
node_execution_id="test_execution_id",
|
||||
)
|
||||
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
|
||||
|
|
@ -79,7 +77,6 @@ class TestDraftVariableSaver:
|
|||
app_id=test_app_id,
|
||||
node_id=_NODE_ID,
|
||||
node_type=NodeType.START,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
node_execution_id="test_execution_id",
|
||||
)
|
||||
for idx, c in enumerate(cases, 1):
|
||||
|
|
@ -94,45 +91,70 @@ class TestWorkflowDraftVariableService:
|
|||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def _create_test_workflow(self, app_id: str) -> Workflow:
|
||||
"""Create a real Workflow instance for testing"""
|
||||
return Workflow.new(
|
||||
tenant_id="test_tenant_id",
|
||||
app_id=app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features="{}",
|
||||
created_by="test_user_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
def test_reset_conversation_variable(self):
|
||||
"""Test resetting a conversation variable"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real conversation variable
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id, name="test_var", value=test_value, description="Test conversation variable"
|
||||
)
|
||||
|
||||
# Mock the _reset_conv_var method
|
||||
expected_result = Mock(spec=WorkflowDraftVariable)
|
||||
expected_result = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id,
|
||||
name="test_var",
|
||||
value=StringSegment(value="reset_value"),
|
||||
)
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
|
||||
result = service.reset_variable(mock_workflow, mock_variable)
|
||||
result = service.reset_variable(workflow, variable)
|
||||
|
||||
mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable)
|
||||
mock_reset_conv.assert_called_once_with(workflow, variable)
|
||||
assert result == expected_result
|
||||
|
||||
def test_reset_node_variable_with_no_execution_id(self):
|
||||
"""Test resetting a node variable with no execution ID - should delete variable"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with no execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = None
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
# Create real node variable with no execution ID
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
name="test_var",
|
||||
value=test_value,
|
||||
node_execution_id="exec-id", # Set initially
|
||||
)
|
||||
# Manually set to None to simulate the test condition
|
||||
variable.node_execution_id = None
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=mock_variable)
|
||||
mock_session.delete.assert_called_once_with(instance=variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
|
|
@ -140,25 +162,25 @@ class TestWorkflowDraftVariableService:
|
|||
"""Test resetting a node variable when execution record doesn't exist"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = "exec-id"
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with execution ID
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
|
||||
)
|
||||
|
||||
# Mock session.scalars to return None (no execution record found)
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.first.return_value = None
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=mock_variable)
|
||||
mock_session.delete.assert_called_once_with(instance=variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
|
|
@ -166,17 +188,15 @@ class TestWorkflowDraftVariableService:
|
|||
"""Test resetting a node variable with valid execution record - should restore from execution"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = "exec-id"
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
mock_variable.node_id = "node-id"
|
||||
mock_variable.value_type = SegmentType.STRING
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with execution ID
|
||||
test_value = StringSegment(value="original_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
|
|
@ -190,33 +210,164 @@ class TestWorkflowDraftVariableService:
|
|||
|
||||
# Mock workflow methods
|
||||
mock_node_config = {"type": "test_node"}
|
||||
mock_workflow.get_node_config_by_id.return_value = mock_node_config
|
||||
mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM
|
||||
with (
|
||||
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
|
||||
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
|
||||
):
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
# Verify last_edited_at was reset
|
||||
assert variable.last_edited_at is None
|
||||
# Verify session.flush was called
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
# Verify variable.set_value was called with the correct value
|
||||
mock_variable.set_value.assert_called_once()
|
||||
# Verify last_edited_at was reset
|
||||
assert mock_variable.last_edited_at is None
|
||||
# Verify session.flush was called
|
||||
mock_session.flush.assert_called()
|
||||
# Should return the updated variable
|
||||
assert result == variable
|
||||
|
||||
# Should return the updated variable
|
||||
assert result == mock_variable
|
||||
|
||||
def test_reset_system_variable_raises_error(self):
|
||||
"""Test that resetting a system variable raises an error"""
|
||||
def test_reset_non_editable_system_variable_raises_error(self):
|
||||
"""Test that resetting a non-editable system variable raises an error"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.SYS # Not a valid enum value for this test
|
||||
mock_variable.id = "var-id"
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
service.reset_variable(mock_workflow, mock_variable)
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert "variable_id=var-id" in str(exc_info.value)
|
||||
# Create a non-editable system variable (workflow_id is not editable)
|
||||
test_value = StringSegment(value="test_workflow_id")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="workflow_id", # This is not in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=False, # Non-editable system variable
|
||||
)
|
||||
|
||||
# Mock the service to properly check system variable editability
|
||||
with patch.object(service, "reset_variable") as mock_reset:
|
||||
|
||||
def side_effect(wf, var):
|
||||
if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name):
|
||||
raise VariableResetError(f"cannot reset system variable, variable_id={var.id}")
|
||||
return var
|
||||
|
||||
mock_reset.side_effect = side_effect
|
||||
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
service.reset_variable(workflow, variable)
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert f"variable_id={variable.id}" in str(exc_info.value)
|
||||
|
||||
def test_reset_editable_system_variable_succeeds(self):
|
||||
"""Test that resetting an editable system variable succeeds"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create an editable system variable (files is editable)
|
||||
test_value = StringSegment(value="[]")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="files", # This is in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=True, # Editable system variable
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"sys.files": "[]"}
|
||||
|
||||
# Mock session.scalars to return the execution record
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.first.return_value = mock_execution
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should succeed and return the variable
|
||||
assert result == variable
|
||||
assert variable.last_edited_at is None
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
def test_reset_query_system_variable_succeeds(self):
|
||||
"""Test that resetting query system variable (another editable one) succeeds"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create an editable system variable (query is editable)
|
||||
test_value = StringSegment(value="original query")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="query", # This is in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=True, # Editable system variable
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"sys.query": "reset query"}
|
||||
|
||||
# Mock session.scalars to return the execution record
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.first.return_value = mock_execution
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should succeed and return the variable
|
||||
assert result == variable
|
||||
assert variable.last_edited_at is None
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
def test_system_variable_editability_check(self):
|
||||
"""Test the system variable editability function directly"""
|
||||
# Test editable system variables
|
||||
assert is_system_variable_editable("files") == True
|
||||
assert is_system_variable_editable("query") == True
|
||||
|
||||
# Test non-editable system variables
|
||||
assert is_system_variable_editable("workflow_id") == False
|
||||
assert is_system_variable_editable("conversation_id") == False
|
||||
assert is_system_variable_editable("user_id") == False
|
||||
|
||||
def test_workflow_draft_variable_factory_methods(self):
|
||||
"""Test that factory methods create proper instances"""
|
||||
test_app_id = self._get_test_app_id()
|
||||
test_value = StringSegment(value="test_value")
|
||||
|
||||
# Test conversation variable factory
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id, name="conv_var", value=test_value, description="Test conversation variable"
|
||||
)
|
||||
assert conv_var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
assert conv_var.editable == True
|
||||
assert conv_var.node_execution_id is None
|
||||
|
||||
# Test system variable factory
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id, name="workflow_id", value=test_value, node_execution_id="exec-id", editable=False
|
||||
)
|
||||
assert sys_var.get_variable_type() == DraftVariableType.SYS
|
||||
assert sys_var.editable == False
|
||||
assert sys_var.node_execution_id == "exec-id"
|
||||
|
||||
# Test node variable factory
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="node-id",
|
||||
name="node_var",
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
assert node_var.get_variable_type() == DraftVariableType.NODE
|
||||
assert node_var.visible == True
|
||||
assert node_var.editable == True
|
||||
assert node_var.node_execution_id == "exec-id"
|
||||
|
|
|
|||
4283
api/uv.lock
4283
api/uv.lock
File diff suppressed because it is too large
Load Diff
|
|
@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.."
|
|||
|
||||
# run mypy checks
|
||||
uv run --directory api --dev --with pip \
|
||||
python -m mypy --install-types --non-interactive ./
|
||||
python -m mypy --install-types --non-interactive --exclude venv ./
|
||||
|
|
|
|||
|
|
@ -285,6 +285,7 @@ BROKER_USE_SSL=false
|
|||
# If you are using Redis Sentinel for high availability, configure the following settings.
|
||||
CELERY_USE_SENTINEL=false
|
||||
CELERY_SENTINEL_MASTER_NAME=
|
||||
CELERY_SENTINEL_PASSWORD=
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT=0.1
|
||||
|
||||
# ------------------------------
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
|
|||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.4.3
|
||||
image: langgenius/dify-api:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -31,7 +31,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:1.4.3
|
||||
image: langgenius/dify-api:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -57,7 +57,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.4.3
|
||||
image: langgenius/dify-web:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
@ -142,7 +142,7 @@ services:
|
|||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.1.2-local
|
||||
image: langgenius/dify-plugin-daemon:0.1.3-local
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
|
|||
|
|
@ -79,6 +79,7 @@ x-shared-env: &shared-api-worker-env
|
|||
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
|
||||
CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false}
|
||||
CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-}
|
||||
CELERY_SENTINEL_PASSWORD: ${CELERY_SENTINEL_PASSWORD:-}
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1}
|
||||
WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*}
|
||||
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
|
||||
|
|
@ -516,7 +517,7 @@ x-shared-env: &shared-api-worker-env
|
|||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.4.3
|
||||
image: langgenius/dify-api:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -545,7 +546,7 @@ services:
|
|||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:1.4.3
|
||||
image: langgenius/dify-api:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
@ -571,7 +572,7 @@ services:
|
|||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.4.3
|
||||
image: langgenius/dify-web:1.5.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
|
@ -656,7 +657,7 @@ services:
|
|||
|
||||
# plugin daemon
|
||||
plugin_daemon:
|
||||
image: langgenius/dify-plugin-daemon:0.1.2-local
|
||||
image: langgenius/dify-plugin-daemon:0.1.3-local
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
|
|
|
|||
Binary file not shown.
|
Before Width: | Height: | Size: 60 KiB After Width: | Height: | Size: 187 KiB |
|
|
@ -0,0 +1,248 @@
|
|||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from events.event_handlers.update_provider_when_message_created import (
|
||||
handle,
|
||||
get_update_stats,
|
||||
)
|
||||
from models.provider import ProviderType
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
|
||||
class TestProviderUpdateDeadlockPrevention:
|
||||
"""Test suite for deadlock prevention in Provider updates."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test fixtures."""
|
||||
self.mock_message = Mock()
|
||||
self.mock_message.answer_tokens = 100
|
||||
|
||||
self.mock_app_config = Mock()
|
||||
self.mock_app_config.tenant_id = "test-tenant-123"
|
||||
|
||||
self.mock_model_conf = Mock()
|
||||
self.mock_model_conf.provider = "openai"
|
||||
|
||||
self.mock_system_config = Mock()
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||
|
||||
self.mock_provider_config = Mock()
|
||||
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
|
||||
self.mock_provider_config.system_configuration = self.mock_system_config
|
||||
|
||||
self.mock_provider_bundle = Mock()
|
||||
self.mock_provider_bundle.configuration = self.mock_provider_config
|
||||
|
||||
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
|
||||
|
||||
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
|
||||
self.mock_generate_entity.app_config = self.mock_app_config
|
||||
self.mock_generate_entity.model_conf = self.mock_model_conf
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_consolidated_handler_basic_functionality(self, mock_db):
|
||||
"""Test that the consolidated handler performs both updates correctly."""
|
||||
# Setup mock query chain
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1 # 1 row affected
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify db.session.query was called
|
||||
assert mock_db.session.query.called
|
||||
|
||||
# Verify commit was called
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
# Verify no rollback was called
|
||||
assert not mock_db.session.rollback.called
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_deadlock_retry_mechanism(self, mock_db):
|
||||
"""Test that deadlock errors trigger retry logic."""
|
||||
# Setup mock to raise deadlock error on first attempt, succeed on second
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# First call raises deadlock, second succeeds
|
||||
mock_db.session.commit.side_effect = [
|
||||
OperationalError("deadlock detected", None, None),
|
||||
None, # Success on retry
|
||||
]
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify commit was called twice (original + retry)
|
||||
assert mock_db.session.commit.call_count == 2
|
||||
|
||||
# Verify rollback was called once (after first failure)
|
||||
mock_db.session.rollback.assert_called_once()
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
|
||||
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
|
||||
"""Test that retry delays follow exponential backoff pattern."""
|
||||
# Setup mock to fail twice, succeed on third attempt
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
mock_db.session.commit.side_effect = [
|
||||
OperationalError("deadlock detected", None, None),
|
||||
OperationalError("deadlock detected", None, None),
|
||||
None, # Success on third attempt
|
||||
]
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify sleep was called twice with increasing delays
|
||||
assert mock_sleep.call_count == 2
|
||||
|
||||
# First delay should be around 0.1s + jitter
|
||||
first_delay = mock_sleep.call_args_list[0][0][0]
|
||||
assert 0.1 <= first_delay <= 0.3
|
||||
|
||||
# Second delay should be around 0.2s + jitter
|
||||
second_delay = mock_sleep.call_args_list[1][0][0]
|
||||
assert 0.2 <= second_delay <= 0.4
|
||||
|
||||
def test_concurrent_handler_execution(self):
|
||||
"""Test that multiple handlers can run concurrently without deadlock."""
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def run_handler():
|
||||
try:
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
handle(
|
||||
self.mock_message,
|
||||
application_generate_entity=self.mock_generate_entity,
|
||||
)
|
||||
results.append("success")
|
||||
except Exception as e:
|
||||
errors.append(str(e))
|
||||
|
||||
# Run multiple handlers concurrently
|
||||
threads = []
|
||||
for _ in range(5):
|
||||
thread = threading.Thread(target=run_handler)
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join(timeout=5)
|
||||
|
||||
# Verify all handlers completed successfully
|
||||
assert len(results) == 5
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_performance_stats_tracking(self):
|
||||
"""Test that performance statistics are tracked correctly."""
|
||||
# Reset stats
|
||||
stats = get_update_stats()
|
||||
initial_total = stats["total_updates"]
|
||||
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(
|
||||
self.mock_message, application_generate_entity=self.mock_generate_entity
|
||||
)
|
||||
|
||||
# Check that stats were updated
|
||||
updated_stats = get_update_stats()
|
||||
assert updated_stats["total_updates"] == initial_total + 1
|
||||
assert updated_stats["successful_updates"] >= initial_total + 1
|
||||
|
||||
def test_non_chat_entity_ignored(self):
|
||||
"""Test that non-chat entities are ignored by the handler."""
|
||||
# Create a non-chat entity
|
||||
mock_non_chat_entity = Mock()
|
||||
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
|
||||
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
# Call handler with non-chat entity
|
||||
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
|
||||
|
||||
# Verify no database operations were performed
|
||||
assert not mock_db.session.query.called
|
||||
assert not mock_db.session.commit.called
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_quota_calculation_tokens(self, mock_db):
|
||||
"""Test quota calculation for token-based quotas."""
|
||||
# Setup token-based quota
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||
self.mock_message.answer_tokens = 150
|
||||
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify update was called with token count
|
||||
update_calls = mock_query.update.call_args_list
|
||||
|
||||
# Should have at least one call with quota_used update
|
||||
quota_update_found = False
|
||||
for call in update_calls:
|
||||
values = call[0][0] # First argument to update()
|
||||
if "quota_used" in values:
|
||||
quota_update_found = True
|
||||
break
|
||||
|
||||
assert quota_update_found
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_quota_calculation_times(self, mock_db):
|
||||
"""Test quota calculation for times-based quotas."""
|
||||
# Setup times-based quota
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
|
||||
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify update was called
|
||||
assert mock_query.update.called
|
||||
assert mock_db.session.commit.called
|
||||
|
|
@ -36,6 +36,7 @@ import AccessControl from '@/app/components/app/app-access-control'
|
|||
import { AccessMode } from '@/models/access-control'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { formatTime } from '@/utils/time'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
|
||||
export type AppCardProps = {
|
||||
app: App
|
||||
|
|
@ -190,6 +191,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
|||
}, [onRefresh, mutateApps, setShowAccessControl])
|
||||
|
||||
const Operations = (props: HtmlContentProps) => {
|
||||
const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp } = useGetUserCanAccessApp({ appId: app?.id, enabled: (!!props?.open && systemFeatures.webapp_auth.enabled) })
|
||||
const onMouseLeave = async () => {
|
||||
props.onClose?.()
|
||||
}
|
||||
|
|
@ -267,10 +269,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
|||
</button>
|
||||
</>
|
||||
)}
|
||||
<Divider className="my-1" />
|
||||
<button className='mx-1 flex h-8 cursor-pointer items-center gap-2 rounded-lg px-3 hover:bg-state-base-hover' onClick={onClickInstalledApp}>
|
||||
<span className='system-sm-regular text-text-secondary'>{t('app.openInExplore')}</span>
|
||||
</button>
|
||||
{
|
||||
(isGettingUserCanAccessApp || !userCanAccessApp?.result) ? null : <>
|
||||
<Divider className="my-1" />
|
||||
<button className='mx-1 flex h-8 cursor-pointer items-center gap-2 rounded-lg px-3 hover:bg-state-base-hover' onClick={onClickInstalledApp}>
|
||||
<span className='system-sm-regular text-text-secondary'>{t('app.openInExplore')}</span>
|
||||
</button>
|
||||
</>
|
||||
}
|
||||
<Divider className="my-1" />
|
||||
{
|
||||
systemFeatures.webapp_auth.enabled && isCurrentWorkspaceEditor && <>
|
||||
|
|
|
|||
|
|
@ -25,10 +25,13 @@ const Layout: FC<{
|
|||
}
|
||||
|
||||
let appCode: string | null = null
|
||||
if (redirectUrl)
|
||||
appCode = redirectUrl?.split('/').pop() || null
|
||||
else
|
||||
if (redirectUrl) {
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
appCode = url.pathname.split('/').pop() || null
|
||||
}
|
||||
else {
|
||||
appCode = pathname.split('/').pop() || null
|
||||
}
|
||||
|
||||
if (!appCode)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -25,7 +25,10 @@ export default function CheckCode() {
|
|||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
@ -62,7 +65,7 @@ export default function CheckCode() {
|
|||
localStorage.setItem('webapp_access_token', ret.data.access_token)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: ret.data.access_token })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}
|
||||
}
|
||||
catch (error) { console.error(error) }
|
||||
|
|
|
|||
|
|
@ -23,7 +23,10 @@ const ExternalMemberSSOAuth = () => {
|
|||
}
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
'use client'
|
||||
import Link from 'next/link'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
|
@ -33,7 +34,10 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
|||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
@ -87,7 +91,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
|||
localStorage.setItem('webapp_access_token', res.data.access_token)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: res.data.access_token })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}
|
||||
else {
|
||||
Toast.notify({
|
||||
|
|
|
|||
|
|
@ -23,7 +23,10 @@ const SSOAuth: FC<SSOAuthProps> = ({
|
|||
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,10 @@ const WebSSOForm: FC = () => {
|
|||
}
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!redirectUrl)
|
||||
return null
|
||||
const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
|
||||
const appCode = url.pathname.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
|
|
@ -63,20 +66,20 @@ const WebSSOForm: FC = () => {
|
|||
localStorage.setItem('webapp_access_token', tokenFromUrl)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: tokenFromUrl })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
return
|
||||
}
|
||||
if (appCode && redirectUrl && localStorage.getItem('webapp_access_token')) {
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: localStorage.getItem('webapp_access_token') })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}
|
||||
})()
|
||||
}, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl, message])
|
||||
|
||||
useEffect(() => {
|
||||
if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC && redirectUrl)
|
||||
router.replace(redirectUrl)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}, [webAppAccessMode, router, redirectUrl])
|
||||
|
||||
if (tokenFromUrl) {
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
|
|||
</div>
|
||||
{/* description */}
|
||||
{appDetail.description && (
|
||||
<div className='system-xs-regular overflow-wrap-anywhere w-full max-w-full whitespace-normal break-words text-text-tertiary'>{appDetail.description}</div>
|
||||
<div className='system-xs-regular overflow-wrap-anywhere max-h-[105px] w-full max-w-full overflow-y-auto whitespace-normal break-words text-text-tertiary'>{appDetail.description}</div>
|
||||
)}
|
||||
{/* operations */}
|
||||
<div className='flex flex-wrap items-center gap-1 self-stretch'>
|
||||
|
|
|
|||
|
|
@ -80,6 +80,8 @@ import {
|
|||
import PluginDependency from '@/app/components/workflow/plugin-dependency'
|
||||
import { supportFunctionCall } from '@/utils/tool-call'
|
||||
import { MittProvider } from '@/context/mitt-context'
|
||||
import { fetchAndMergeValidCompletionParams } from '@/utils/completion-params'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
type PublishConfig = {
|
||||
modelConfig: ModelConfig
|
||||
|
|
@ -453,7 +455,21 @@ const Configuration: FC = () => {
|
|||
...visionConfig,
|
||||
enabled: supportVision,
|
||||
}, true)
|
||||
setCompletionParams({})
|
||||
|
||||
try {
|
||||
const { params: filtered, removedDetails } = await fetchAndMergeValidCompletionParams(
|
||||
provider,
|
||||
modelId,
|
||||
completionParams,
|
||||
)
|
||||
if (Object.keys(removedDetails).length)
|
||||
Toast.notify({ type: 'warning', message: `${t('common.modelProvider.parametersInvalidRemoved')}: ${Object.entries(removedDetails).map(([k, reason]) => `${k} (${reason})`).join(', ')}` })
|
||||
setCompletionParams(filtered)
|
||||
}
|
||||
catch (e) {
|
||||
Toast.notify({ type: 'error', message: t('common.error') })
|
||||
setCompletionParams({})
|
||||
}
|
||||
}
|
||||
|
||||
const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision)
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue