diff --git a/api/README.md b/api/README.md index 5ecf92a4f0..e75ea3d354 100644 --- a/api/README.md +++ b/api/README.md @@ -80,10 +80,10 @@ 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. ```bash -uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation +uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation ``` -Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: +Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: ```bash uv run celery -A app.celery beat diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 2affbd6a42..60eedd2197 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,4 +1,5 @@ -import flask_restx +from typing import Any, cast + from flask import request from flask_login import current_user from flask_restx import Resource, fields, marshal, marshal_with, reqparse @@ -31,12 +32,13 @@ from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fi from fields.document_fields import document_status_fields from libs.login import login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile +from models.account import Account from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService -def _validate_name(name): +def _validate_name(name: str) -> str: if not name or len(name) < 1 or len(name) > 40: raise ValueError("Name must be between 1 to 40 characters.") return name @@ -92,7 +94,7 @@ class DatasetListApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - data = marshal(datasets, dataset_detail_fields) + data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields)) for item in data: # convert embedding_model_provider to plugin standard format if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: @@ -192,7 +194,7 @@ class DatasetListApi(Resource): name=args["name"], description=args["description"], indexing_technique=args["indexing_technique"], - account=current_user, + account=cast(Account, current_user), permission=DatasetPermissionEnum.ONLY_ME, provider=args["provider"], external_knowledge_api_id=args["external_knowledge_api_id"], @@ -224,7 +226,7 @@ class DatasetApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - data = marshal(dataset, dataset_detail_fields) + data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) if dataset.indexing_technique == "high_quality": if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) @@ -369,7 +371,7 @@ class DatasetApi(Resource): if dataset is None: raise NotFound("Dataset not found.") - result_data = marshal(dataset, dataset_detail_fields) + result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) tenant_id = 
current_user.current_tenant_id if data.get("partial_member_list") and data.get("permission") == "partial_members": @@ -688,7 +690,7 @@ class DatasetApiKeyApi(Resource): ) if current_key_count >= self.max_keys: - flask_restx.abort( + api.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", code="max_keys_exceeded", @@ -733,7 +735,7 @@ class DatasetApiDeleteApi(Resource): ) if key is None: - flask_restx.abort(404, message="API key not found") + api.abort(404, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6aaede0fb3..c5fa2061bf 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -55,6 +55,7 @@ from fields.document_fields import ( from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile +from models.account import Account from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig @@ -418,7 +419,9 @@ class DatasetInitApi(Resource): try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user + tenant_id=current_user.current_tenant_id, + knowledge_config=knowledge_config, + account=cast(Account, current_user), ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -452,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource): raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule - data_process_rule_dict = data_process_rule.to_dict() + data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {} response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} @@ -514,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not documents: return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 data_process_rule = documents[0].dataset_process_rule - data_process_rule_dict = data_process_rule.to_dict() + data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {} extract_settings = [] for document in documents: if document.indexing_status in {"completed", "error"}: @@ -753,7 +756,7 @@ class DocumentApi(DocumentResource): } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) - document_process_rules = document.dataset_process_rule.to_dict() + document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -1073,7 +1076,9 @@ class DocumentRenameApi(DocumentResource): if not current_user.is_dataset_editor: raise Forbidden() dataset = DatasetService.get_dataset(dataset_id) - DatasetService.check_dataset_operator_permission(current_user, dataset) + if not dataset: + raise NotFound("Dataset not found.") + DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset) parser = reqparse.RequestParser() parser.add_argument("name", 
type=str, required=True, nullable=False, location="json") args = parser.parse_args() diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index ba552821d2..9f2805e2c6 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -392,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource): # send batch add segments task redis_client.setnx(indexing_cache_key, "waiting") batch_create_segment_to_index_task.delay( - str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id + str(job_id), + upload_file_id, + dataset_id, + document_id, + current_user.current_tenant_id, + current_user.id, ) except Exception as e: return {"error": str(e)}, 500 @@ -468,7 +473,8 @@ class ChildChunkAddApi(Resource): parser.add_argument("content", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + content = args["content"] + child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 @@ -557,7 +563,8 @@ class ChildChunkAddApi(Resource): parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") args = parser.parse_args() try: - chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] + chunks_data = args["chunks"] + chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data] child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) @@ -674,9 +681,8 @@ class ChildChunkUpdateApi(Resource): parser.add_argument("content", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - child_chunk = SegmentService.update_child_chunk( - args.get("content"), child_chunk, segment, document, dataset - ) + content = args["content"] + child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index e8f5a11b41..adf9f53523 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,3 +1,5 @@ +from typing import cast + from flask import request from flask_login import current_user from flask_restx import Resource, fields, marshal, reqparse @@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import account_initialization_required, setup_required from fields.dataset_fields import dataset_detail_fields from libs.login import login_required +from models.account import Account from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService -def _validate_name(name): +def _validate_name(name: str) -> str: if not name or len(name) < 1 or len(name) > 100: raise 
ValueError("Name must be between 1 to 100 characters.") return name @@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource): response = HitTestingService.external_retrieve( dataset=dataset, query=args["query"], - account=current_user, + account=cast(Account, current_user), external_retrieval_model=args["external_retrieval_model"], metadata_filtering_conditions=args["metadata_filtering_conditions"], ) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index cfbfc50873..a68e337135 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,10 +1,11 @@ import logging +from typing import cast from flask_login import current_user from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -import services.dataset_service +import services from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -20,6 +21,7 @@ from core.errors.error import ( ) from core.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields +from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService @@ -59,7 +61,7 @@ class DatasetsHitTestingBase: response = HitTestingService.retrieve( dataset=dataset, query=args["query"], - account=current_user, + account=cast(Account, current_user), retrieval_model=args["retrieval_model"], external_retrieval_model=args["external_retrieval_model"], limit=10, diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 53dc80eaa5..dc3cd3fce9 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -62,6 +62,7 @@ class DatasetMetadataApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() + name = args["name"] dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -70,7 +71,7 @@ class DatasetMetadataApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) + metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name) return metadata, 200 @setup_required diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 6641911243..3af590afc8 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -20,13 +20,13 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService logger = logging.getLogger(__name__) -def _validate_name(name): +def _validate_name(name: str) -> str: if not name or len(name) < 1 or len(name) > 40: raise ValueError("Name must be between 1 to 40 characters.") return name -def _validate_description_length(description): +def _validate_description_length(description: str) -> str: if len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description @@ -76,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource): ) parser.add_argument( "description", - type=str, + 
type=_validate_description_length, nullable=True, required=False, default="", @@ -133,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index c741bfbf82..404aa42073 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,5 +1,5 @@ -from flask_login import current_user # type: ignore # type: ignore -from flask_restx import Resource, marshal, reqparse # type: ignore +from flask_login import current_user +from flask_restx import Resource, marshal, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -20,18 +20,6 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 40: - raise ValueError("Name must be between 1 to 40 characters.") - return name - - -def _validate_description_length(description): - if len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - @console_ns.route("/rag/pipeline/dataset") class CreateRagPipelineDatasetApi(Resource): @setup_required diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 38f75402a8..bef6bfd13e 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,5 +1,5 @@ import logging -from typing import Any, NoReturn +from typing import NoReturn from flask import Response from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse @@ -11,14 +11,12 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, ) from controllers.console.app.workflow_draft_variable import ( - _WORKFLOW_DRAFT_VARIABLE_FIELDS, - _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, + _WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage] + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage] ) from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from core.variables.segment_group import SegmentGroup -from core.variables.segments import ArrayFileSegment, FileSegment, Segment from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db @@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, logger = logging.getLogger(__name__) -def _convert_values_to_json_serializable_object(value: Segment) -> Any: - if isinstance(value, FileSegment): - return value.value.model_dump() - elif isinstance(value, ArrayFileSegment): - return [i.model_dump() for i in value.value] - elif isinstance(value, SegmentGroup): - return [_convert_values_to_json_serializable_object(i) for i in value.value] - else: - return 
value.value - - -def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: - value = variable.get_value() - # create a copy of the value to avoid affecting the model cache. - value = value.model_copy(deep=True) - # Refresh the url signature before returning it to client. - if isinstance(value, FileSegment): - file = value.value - file.remote_url = file.generate_url() - elif isinstance(value, ArrayFileSegment): - files = value.value - for file in files: - file.remote_url = file.generate_url() - return _convert_values_to_json_serializable_object(value) - - def _create_pagination_parser(): parser = reqparse.RequestParser() parser.add_argument( @@ -104,7 +76,7 @@ def _api_prerequisite(f): @account_initialization_required @get_rag_pipeline def wrapper(*args, **kwargs): - if not isinstance(current_user, Account) or not current_user.is_editor: + if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index da236ee5af..3d29b3ee61 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -3,7 +3,7 @@ from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import login_required @@ -17,6 +17,7 @@ def _validate_name(name): return name +@console_ns.route("/tags") class TagListApi(Resource): @setup_required @login_required @@ -52,6 +53,7 @@ class TagListApi(Resource): return response, 200 +@console_ns.route("/tags/<uuid:tag_id>") class TagUpdateDeleteApi(Resource): @setup_required @login_required @@ -89,6 +91,7 @@ class TagUpdateDeleteApi(Resource): return 204 +@console_ns.route("/tag-bindings/create") class TagBindingCreateApi(Resource): @setup_required @login_required @@ -114,6 +117,7 @@ class TagBindingCreateApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/tag-bindings/remove") class TagBindingDeleteApi(Resource): @setup_required @login_required @@ -133,9 +137,3 @@ class TagBindingDeleteApi(Resource): TagService.delete_tag_binding(args) return {"result": "success"}, 200 - - -api.add_resource(TagListApi, "/tags") -api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>") -api.add_resource(TagBindingCreateApi, "/tag-bindings/create") -api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 6a70345f7c..72ab05cec0 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,10 +1,10 @@ -from typing import Literal +from typing import Any, Literal, cast from flask import request from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, NotFound -import services.dataset_service +import services from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.wraps import ( @@ -254,19 +254,21 @@ class DatasetListApi(DatasetApiResource): """Resource for creating datasets.""" args = dataset_create_parser.parse_args() - if args.get("embedding_model_provider"): - 
DatasetService.check_embedding_model_setting( - tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") - ) + embedding_model_provider = args.get("embedding_model_provider") + embedding_model = args.get("embedding_model") + if embedding_model_provider and embedding_model: + DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) + + retrieval_model = args.get("retrieval_model") if ( - args.get("retrieval_model") - and args.get("retrieval_model").get("reranking_model") - and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + retrieval_model + and retrieval_model.get("reranking_model") + and retrieval_model.get("reranking_model").get("reranking_provider_name") ): DatasetService.check_reranking_model_setting( tenant_id, - args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), - args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + retrieval_model.get("reranking_model").get("reranking_provider_name"), + retrieval_model.get("reranking_model").get("reranking_model_name"), ) try: @@ -317,7 +319,7 @@ class DatasetApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - data = marshal(dataset, dataset_detail_fields) + data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) # check embedding setting provider_manager = ProviderManager() assert isinstance(current_user, Account) @@ -331,8 +333,8 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": - item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" + if data.get("indexing_technique") == "high_quality": + item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True else: @@ -341,7 +343,9 @@ class DatasetApi(DatasetApiResource): data["embedding_available"] = True # force update search method to keyword_search if indexing_technique is economic - data["retrieval_model_dict"]["search_method"] = "keyword_search" + retrieval_model_dict = data.get("retrieval_model_dict") + if retrieval_model_dict: + retrieval_model_dict["search_method"] = "keyword_search" if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) @@ -372,19 +376,24 @@ class DatasetApi(DatasetApiResource): data = request.get_json() # check embedding model setting - if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"): - DatasetService.check_embedding_model_setting( - dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") - ) + embedding_model_provider = data.get("embedding_model_provider") + embedding_model = data.get("embedding_model") + if data.get("indexing_technique") == "high_quality" or embedding_model_provider: + if embedding_model_provider and embedding_model: + DatasetService.check_embedding_model_setting( + dataset.tenant_id, embedding_model_provider, embedding_model + ) + + retrieval_model = data.get("retrieval_model") if ( - data.get("retrieval_model") - and data.get("retrieval_model").get("reranking_model") - and 
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + retrieval_model + and retrieval_model.get("reranking_model") + and retrieval_model.get("reranking_model").get("reranking_provider_name") ): DatasetService.check_reranking_model_setting( dataset.tenant_id, - data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), - data.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + retrieval_model.get("reranking_model").get("reranking_provider_name"), + retrieval_model.get("reranking_model").get("reranking_model_name"), ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator @@ -397,7 +406,7 @@ class DatasetApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") - result_data = marshal(dataset, dataset_detail_fields) + result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) assert isinstance(current_user, Account) tenant_id = current_user.current_tenant_id @@ -591,9 +600,10 @@ class DatasetTagsApi(DatasetApiResource): args = tag_update_parser.parse_args() args["type"] = "knowledge" - tag = TagService.update_tags(args, args.get("tag_id")) + tag_id = args["tag_id"] + tag = TagService.update_tags(args, tag_id) - binding_count = TagService.get_tag_binding_count(args.get("tag_id")) + binding_count = TagService.get_tag_binding_count(tag_id) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} @@ -616,7 +626,7 @@ class DatasetTagsApi(DatasetApiResource): if not current_user.has_edit_permission: raise Forbidden() args = tag_delete_parser.parse_args() - TagService.delete_tag(args.get("tag_id")) + TagService.delete_tag(args["tag_id"]) return 204 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index e01bc8940c..c1122acd7b 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -108,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource): if text is None or name is None: raise ValueError("Both 'text' and 'name' must be non-null values.") - if args.get("embedding_model_provider"): - DatasetService.check_embedding_model_setting( - tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") - ) + embedding_model_provider = args.get("embedding_model_provider") + embedding_model = args.get("embedding_model") + if embedding_model_provider and embedding_model: + DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) + + retrieval_model = args.get("retrieval_model") if ( - args.get("retrieval_model") - and args.get("retrieval_model").get("reranking_model") - and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + retrieval_model + and retrieval_model.get("reranking_model") + and retrieval_model.get("reranking_model").get("reranking_provider_name") ): DatasetService.check_reranking_model_setting( tenant_id, - args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), - args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + retrieval_model.get("reranking_model").get("reranking_provider_name"), + retrieval_model.get("reranking_model").get("reranking_model_name"), ) if not current_user: @@ -187,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource): if not dataset: raise ValueError("Dataset does not exist.") + retrieval_model = 
args.get("retrieval_model") if ( - args.get("retrieval_model") - and args.get("retrieval_model").get("reranking_model") - and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + retrieval_model + and retrieval_model.get("reranking_model") + and retrieval_model.get("reranking_model").get("reranking_provider_name") ): DatasetService.check_reranking_model_setting( tenant_id, - args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), - args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + retrieval_model.get("reranking_model").get("reranking_provider_name"), + retrieval_model.get("reranking_model").get("reranking_model_name"), ) # indexing_technique is already set in dataset since this is an update diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index c6032048e6..e01659dc68 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) + metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"]) return marshal(metadata, dataset_metadata_fields), 200 @service_api_ns.doc("delete_dataset_metadata") diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index ac2967d0c1..dd0b4bedcf 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -18,6 +18,10 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): retriever_from: str model_config = ConfigDict(arbitrary_types_allowed=True) + def run(self, query: str) -> str: + """Use the tool.""" + return self._run(query) + @abstractmethod def _run(self, query: str) -> str: """Use the tool. 
diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index a62d419243..fca6e6f1c7 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -124,7 +124,7 @@ class DatasetRetrieverTool(Tool): yield self.create_text_message(text="please input query") else: # invoke dataset retriever tool - result = self.retrieval_tool._run(query=query) + result = self.retrieval_tool.run(query=query) yield self.create_text_message(text=result) def validate_credentials( diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 2e306db6c7..fcb1d325af 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -2,6 +2,7 @@ import re from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError +from typing import Any from flask import request from requests import get @@ -127,34 +128,34 @@ class ApiBasedToolSchemaParser: if "allOf" in prop_dict: del prop_dict["allOf"] - # parse body parameters - if "schema" in interface["operation"]["requestBody"]["content"][content_type]: - body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] - required = body_schema.get("required", []) - properties = body_schema.get("properties", {}) - for name, property in properties.items(): - tool = ToolParameter( - name=name, - label=I18nObject(en_US=name, zh_Hans=name), - human_description=I18nObject( - en_US=property.get("description", ""), zh_Hans=property.get("description", "") - ), - type=ToolParameter.ToolParameterType.STRING, - required=name in required, - form=ToolParameter.ToolParameterForm.LLM, - llm_description=property.get("description", ""), - default=property.get("default", None), - placeholder=I18nObject( - en_US=property.get("description", ""), zh_Hans=property.get("description", "") - ), - ) + # parse body parameters + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) + for name, property in properties.items(): + tool = ToolParameter( + name=name, + label=I18nObject(en_US=name, zh_Hans=name), + human_description=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + type=ToolParameter.ToolParameterType.STRING, + required=name in required, + form=ToolParameter.ToolParameterForm.LLM, + llm_description=property.get("description", ""), + default=property.get("default", None), + placeholder=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), + ) - # check if there is a type - typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) - if typ: - tool.type = typ + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) + if typ: + tool.type = typ - parameters.append(tool) + parameters.append(tool) # check if parameters is duplicated parameters_count = {} @@ -241,7 +242,9 @@ class ApiBasedToolSchemaParser: return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) @staticmethod - def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None): + def parse_swagger_to_openapi( + swagger: dict, extra_info: dict | None = None, warning: dict | None = None + 
) -> dict[str, Any]: warning = warning or {} """ parse swagger to openapi @@ -257,7 +260,7 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - openapi = { + converted_openapi: dict[str, Any] = { "openapi": "3.0.0", "info": { "title": info.get("title", "Swagger"), @@ -275,7 +278,7 @@ class ApiBasedToolSchemaParser: # convert paths for path, path_item in swagger["paths"].items(): - openapi["paths"][path] = {} + converted_openapi["paths"][path] = {} for method, operation in path_item.items(): if "operationId" not in operation: raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") @@ -286,7 +289,7 @@ class ApiBasedToolSchemaParser: if warning is not None: warning["missing_summary"] = f"No summary or description found in operation {method} {path}." - openapi["paths"][path][method] = { + converted_openapi["paths"][path][method] = { "operationId": operation["operationId"], "summary": operation.get("summary", ""), "description": operation.get("description", ""), @@ -295,13 +298,14 @@ class ApiBasedToolSchemaParser: } if "requestBody" in operation: - openapi["paths"][path][method]["requestBody"] = operation["requestBody"] + converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # convert definitions - for name, definition in swagger["definitions"].items(): - openapi["components"]["schemas"][name] = definition + if "definitions" in swagger: + for name, definition in swagger["definitions"].items(): + converted_openapi["components"]["schemas"][name] = definition - return openapi + return converted_openapi @staticmethod def parse_openai_plugin_json_to_tool_bundle( diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 8ceabde7e6..2dc00fd70b 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -184,11 +184,22 @@ class VariablePool(BaseModel): """Extract the actual value from an ObjectSegment.""" return obj.value if isinstance(obj, ObjectSegment) else obj - def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str): - """Get a nested attribute from a dictionary-like object.""" - if not isinstance(obj, dict): + def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None: + """ + Get a nested attribute from a dictionary-like object. + + Args: + obj: The dictionary-like object to search. + attr: The key to look up. + + Returns: + Segment | None: + The corresponding Segment built from the attribute value if the key exists, + otherwise None. 
+ """ + if not isinstance(obj, dict) or attr not in obj: return None - return obj.get(attr) + return variable_factory.build_segment(obj.get(attr)) def remove(self, selector: Sequence[str], /): """ diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 1a417b5739..a05a6b1b96 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -10,6 +10,8 @@ from typing_extensions import TypeIs from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment +from core.variables.variables import VariableUnion +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.entities import VariablePool from core.workflow.enums import ( ErrorStrategy, @@ -217,6 +219,13 @@ class IterationNode(Node): graph_engine=graph_engine, ) + # Sync conversation variables after each iteration completes + self._sync_conversation_variables_from_snapshot( + self._extract_conversation_variable_snapshot( + variable_pool=graph_engine.graph_runtime_state.variable_pool + ) + ) + # Update the total tokens from this iteration self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() @@ -235,7 +244,10 @@ class IterationNode(Node): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks - future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {} + future_to_index: dict[ + Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]], + int, + ] = {} for index, item in enumerate(iterator_list_value): yield IterationNextEvent(index=index) future = executor.submit( @@ -252,7 +264,7 @@ class IterationNode(Node): index = future_to_index[future] try: result = future.result() - iter_start_at, events, output_value, tokens_used = result + iter_start_at, events, output_value, tokens_used, conversation_snapshot = result # Update outputs at the correct index outputs[index] = output_value @@ -264,6 +276,9 @@ class IterationNode(Node): self.graph_runtime_state.total_tokens += tokens_used iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + # Sync conversation variables after iteration completion + self._sync_conversation_variables_from_snapshot(conversation_snapshot) + except Exception as e: # Handle errors based on error_handle_mode match self._node_data.error_handle_mode: @@ -288,7 +303,7 @@ class IterationNode(Node): item: object, flask_app: Flask, context_vars: contextvars.Context, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]: + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]: """Execute a single iteration in parallel mode and return results.""" with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -307,8 +322,17 @@ class IterationNode(Node): # Get the output value from the temporary outputs list output_value = outputs_temp[0] if outputs_temp else None + conversation_snapshot = self._extract_conversation_variable_snapshot( + variable_pool=graph_engine.graph_runtime_state.variable_pool + ) - return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens + return ( + 
iter_start_at, + events, + output_value, + graph_engine.graph_runtime_state.total_tokens, + conversation_snapshot, + ) def _handle_iteration_success( self, @@ -430,6 +454,23 @@ class IterationNode(Node): return variable_mapping + def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]: + conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) + return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} + + def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None: + parent_pool = self.graph_runtime_state.variable_pool + parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) + + current_keys = set(parent_conversations.keys()) + snapshot_keys = set(snapshot.keys()) + + for removed_key in current_keys - snapshot_keys: + parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key)) + + for name, variable in snapshot.items(): + parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable) + def _append_iteration_info_to_event( self, event: GraphNodeEventBase, diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index febf744369..ccb1c78f4e 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -147,6 +147,7 @@ def init_app(app: DifyApp) -> Celery: } if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") + imports.append("tasks.process_tenant_plugin_autoupgrade_check_task") beat_schedule["check_upgradable_plugin_task"] = { "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task", "schedule": crontab(minute="*/15"), diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 2104e66254..494194369a 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -142,6 +142,8 @@ def build_segment(value: Any, /) -> Segment: # below if value is None: return NoneSegment() + if isinstance(value, Segment): + return value if isinstance(value, str): return StringSegment(value=value) if isinstance(value, bool): diff --git a/api/libs/external_api.py b/api/libs/external_api.py index cf91b0117f..25a82f8a96 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -94,7 +94,7 @@ def register_external_error_handlers(api: Api): got_request_exception.send(current_app, exception=e) status_code = 500 - data = getattr(e, "data", {"message": http_status_message(status_code)}) + data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) if not isinstance(data, dict): diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 9759156c0f..fc38d51005 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -27,7 +27,7 @@ import gmpy2 # type: ignore from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes -from Crypto.Util.py3compat import _copy_bytes, bord +from Crypto.Util.py3compat import bord from Crypto.Util.strxor import strxor @@ -72,7 +72,7 @@ class PKCS1OAepCipher: else: self._mgf = lambda x, y: MGF1(x, y, self._hashObj) - self._label = _copy_bytes(None, None, label) + self._label = bytes(label) self._randfunc = randfunc def 
can_encrypt(self): @@ -120,7 +120,7 @@ class PKCS1OAepCipher: # Step 2b ps = b"\x00" * ps_len # Step 2c - db = lHash + ps + b"\x01" + _copy_bytes(None, None, message) + db = lHash + ps + b"\x01" + bytes(message) # Step 2d ros = self._randfunc(hLen) # Step 2e diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index ecc4b3fb98..a270fa70fa 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -14,7 +14,7 @@ class SendGridClient: def send(self, mail: dict): logger.debug("Sending email with SendGrid") - + _to = "" try: _to = mail["to"] @@ -28,7 +28,7 @@ class SendGridClient: content = Content("text/html", mail["html"]) sg_mail = Mail(from_email, to_email, subject, content) mail_json = sg_mail.get() - response = sg.client.mail.send.post(request_body=mail_json) # ty: ignore [call-non-callable] + response = sg.client.mail.send.post(request_body=mail_json) # type: ignore logger.debug(response.status_code) logger.debug(response.body) logger.debug(response.headers) diff --git a/api/pyproject.toml b/api/pyproject.toml index 319c9a23a6..fb8a035118 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.9.0" +version = "1.9.1" requires-python = ">=3.11,<3.13" dependencies = [ @@ -181,7 +181,7 @@ dev = [ storage = [ "azure-storage-blob==12.13.0", "bce-python-sdk~=0.9.23", - "cos-python-sdk-v5==1.9.30", + "cos-python-sdk-v5==1.9.38", "esdk-obs-python==3.24.6.1", "google-cloud-storage==2.16.0", "opendal~=0.46.0", @@ -208,7 +208,7 @@ vdb = [ "couchbase~=4.3.0", "elasticsearch==8.14.0", "opensearch-py==2.4.0", - "oracledb==3.0.0", + "oracledb==3.3.0", "pgvecto-rs[sqlalchemy]~=0.2.1", "pgvector==0.2.5", "pymilvus~=2.5.0", diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 61ed3ac3b4..1e6cd501ad 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -6,11 +6,7 @@ "migrations/", "core/rag", "extensions", - "libs", - "controllers/console/datasets", - "controllers/service_api/dataset", "core/ops", - "core/tools", "core/model_runtime", "core/workflow/nodes", "core/app/app_config/easy_ui_based_app/dataset" diff --git a/api/schedule/check_upgradable_plugin_task.py b/api/schedule/check_upgradable_plugin_task.py index a9ad27b059..0712100c01 100644 --- a/api/schedule/check_upgradable_plugin_task.py +++ b/api/schedule/check_upgradable_plugin_task.py @@ -6,7 +6,7 @@ import click import app from extensions.ext_database import db from models.account import TenantPluginAutoUpgradeStrategy -from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task +from tasks import process_tenant_plugin_autoupgrade_check_task as check_task AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes MAX_CONCURRENT_CHECK_TASKS = 20 @@ -43,7 +43,7 @@ def check_upgradable_plugin_task(): for i in range(0, total_strategies, MAX_CONCURRENT_CHECK_TASKS): batch_strategies = strategies[i : i + MAX_CONCURRENT_CHECK_TASKS] for strategy in batch_strategies: - process_tenant_plugin_autoupgrade_check_task.delay( + check_task.process_tenant_plugin_autoupgrade_check_task.delay( strategy.tenant_id, strategy.strategy_setting, strategy.upgrade_time_of_day, diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index bae8f1c4db..124971e8e2 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -1,5 +1,5 @@ +import json import operator -import traceback 
import typing import click @@ -9,38 +9,106 @@ from core.helper import marketplace from core.helper.marketplace import MarketplacePluginDeclaration from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller +from extensions.ext_redis import redis_client from models.account import TenantPluginAutoUpgradeStrategy RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3 +CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:" +CACHE_REDIS_TTL = 60 * 15 # 15 minutes -cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {} +def _get_redis_cache_key(plugin_id: str) -> str: + """Generate Redis cache key for plugin manifest.""" + return f"{CACHE_REDIS_KEY_PREFIX}{plugin_id}" + + +def _get_cached_manifest(plugin_id: str) -> typing.Union[MarketplacePluginDeclaration, None, bool]: + """ + Get cached plugin manifest from Redis. + Returns: + - MarketplacePluginDeclaration: if found in cache + - None: if cached as not found (marketplace returned no result) + - False: if not in cache at all + """ + try: + key = _get_redis_cache_key(plugin_id) + cached_data = redis_client.get(key) + if cached_data is None: + return False + + cached_json = json.loads(cached_data) + if cached_json is None: + return None + + return MarketplacePluginDeclaration.model_validate(cached_json) + except Exception: + return False + + +def _set_cached_manifest(plugin_id: str, manifest: typing.Union[MarketplacePluginDeclaration, None]) -> None: + """ + Cache plugin manifest in Redis. + Args: + plugin_id: The plugin ID + manifest: The manifest to cache, or None if not found in marketplace + """ + try: + key = _get_redis_cache_key(plugin_id) + if manifest is None: + # Cache the fact that this plugin was not found + redis_client.setex(key, CACHE_REDIS_TTL, json.dumps(None)) + else: + # Cache the manifest data + redis_client.setex(key, CACHE_REDIS_TTL, manifest.model_dump_json()) + except Exception: + # If Redis fails, continue without caching + # traceback.print_exc() + pass def marketplace_batch_fetch_plugin_manifests( plugin_ids_plain_list: list[str], ) -> list[MarketplacePluginDeclaration]: - global cached_plugin_manifests - # return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list) - not_included_plugin_ids = [ - plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests - ] - if not_included_plugin_ids: - manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids) + """Fetch plugin manifests with Redis caching support.""" + cached_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {} + not_cached_plugin_ids: list[str] = [] + + # Check Redis cache for each plugin + for plugin_id in plugin_ids_plain_list: + cached_result = _get_cached_manifest(plugin_id) + if cached_result is False: + # Not in cache, need to fetch + not_cached_plugin_ids.append(plugin_id) + else: + # Either found manifest or cached as None (not found in marketplace) + # At this point, cached_result is either MarketplacePluginDeclaration or None + if isinstance(cached_result, bool): + # This should never happen due to the if condition above, but for type safety + continue + cached_manifests[plugin_id] = cached_result + + # Fetch uncached plugins from marketplace + if not_cached_plugin_ids: + manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_cached_plugin_ids) + + # Cache the fetched manifests for manifest in 
manifests: - cached_plugin_manifests[manifest.plugin_id] = manifest + cached_manifests[manifest.plugin_id] = manifest + _set_cached_manifest(manifest.plugin_id, manifest) - if ( - len(manifests) == 0 - ): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check - for plugin_id in not_included_plugin_ids: - cached_plugin_manifests[plugin_id] = None + # Cache plugins that were not found in marketplace + fetched_plugin_ids = {manifest.plugin_id for manifest in manifests} + for plugin_id in not_cached_plugin_ids: + if plugin_id not in fetched_plugin_ids: + cached_manifests[plugin_id] = None + _set_cached_manifest(plugin_id, None) + # Build result list from cached manifests result: list[MarketplacePluginDeclaration] = [] for plugin_id in plugin_ids_plain_list: - final_manifest = cached_plugin_manifests.get(plugin_id) - if final_manifest is not None: - result.append(final_manifest) + cached_manifest: typing.Union[MarketplacePluginDeclaration, None] = cached_manifests.get(plugin_id) + if cached_manifest is not None: + result.append(cached_manifest) return result @@ -157,10 +225,10 @@ def process_tenant_plugin_autoupgrade_check_task( ) except Exception as e: click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red")) - traceback.print_exc() + # traceback.print_exc() break except Exception as e: click.echo(click.style(f"Error when checking upgradable plugin: {e}", fg="red")) - traceback.print_exc() + # traceback.print_exc() return diff --git a/api/tests/fixtures/workflow/update-conversation-variable-in-iteration.yml b/api/tests/fixtures/workflow/update-conversation-variable-in-iteration.yml new file mode 100644 index 0000000000..ffc6eb9120 --- /dev/null +++ b/api/tests/fixtures/workflow/update-conversation-variable-in-iteration.yml @@ -0,0 +1,316 @@ +app: + description: 'This chatflow receives a sys.query, writes it into the `answer` variable, + and then outputs the `answer` variable. + + + `answer` is a conversation variable with a blank default value; it will be updated + in an iteration node. + + + If this chatflow works correctly, it will output `sys.query` unchanged.' 
+ icon: 🤖 + icon_background: '#FFEAD5' + mode: advanced-chat + name: update-conversation-variable-in-iteration + use_icon_as_answer_icon: false +dependencies: [] +kind: app +version: 0.4.0 +workflow: + conversation_variables: + - description: '' + id: c30af82d-b2ec-417d-a861-4dd78584faa4 + name: answer + selector: + - conversation + - answer + value: '' + value_type: string + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: code + id: 1759032354471-source-1759032363865-target + source: '1759032354471' + sourceHandle: source + target: '1759032363865' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: false + isInLoop: false + sourceType: code + targetType: iteration + id: 1759032363865-source-1759032379989-target + source: '1759032363865' + sourceHandle: source + target: '1759032379989' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: true + isInLoop: false + iteration_id: '1759032379989' + sourceType: iteration-start + targetType: assigner + id: 1759032379989start-source-1759032394460-target + source: 1759032379989start + sourceHandle: source + target: '1759032394460' + targetHandle: target + type: custom + zIndex: 1002 + - data: + isInIteration: false + isInLoop: false + sourceType: iteration + targetType: answer + id: 1759032379989-source-1759032410331-target + source: '1759032379989' + sourceHandle: source + target: '1759032410331' + targetHandle: target + type: custom + zIndex: 0 + - data: + isInIteration: true + isInLoop: false + iteration_id: '1759032379989' + sourceType: assigner + targetType: code + id: 1759032394460-source-1759032476318-target + source: '1759032394460' + sourceHandle: source + target: '1759032476318' + targetHandle: target + type: custom + zIndex: 1002 + nodes: + - data: + selected: false + title: Start + type: start + variables: [] + height: 52 + id: '1759032354471' + position: + x: 30 + y: 302 + positionAbsolute: + x: 30 + y: 302 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + code: "\ndef main():\n return {\n \"result\": [1],\n }\n" + code_language: python3 + outputs: + result: + children: null + type: array[number] + selected: false + title: Code + type: code + variables: [] + height: 52 + id: '1759032363865' + position: + x: 332 + y: 302 + positionAbsolute: + x: 332 + y: 302 + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + error_handle_mode: terminated + height: 204 + is_parallel: false + iterator_input_type: array[number] + iterator_selector: + - '1759032363865' + - result + output_selector: + - '1759032476318' + - result + output_type: 
array[string] + parallel_nums: 10 + selected: false + start_node_id: 1759032379989start + title: Iteration + type: iteration + width: 808 + height: 204 + id: '1759032379989' + position: + x: 634 + y: 302 + positionAbsolute: + x: 634 + y: 302 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 808 + zIndex: 1 + - data: + desc: '' + isInIteration: true + selected: false + title: '' + type: iteration-start + draggable: false + height: 48 + id: 1759032379989start + parentId: '1759032379989' + position: + x: 60 + y: 78 + positionAbsolute: + x: 694 + y: 380 + selectable: false + sourcePosition: right + targetPosition: left + type: custom-iteration-start + width: 44 + zIndex: 1002 + - data: + isInIteration: true + isInLoop: false + items: + - input_type: variable + operation: over-write + value: + - sys + - query + variable_selector: + - conversation + - answer + write_mode: over-write + iteration_id: '1759032379989' + selected: false + title: Variable Assigner + type: assigner + version: '2' + height: 84 + id: '1759032394460' + parentId: '1759032379989' + position: + x: 204 + y: 60 + positionAbsolute: + x: 838 + y: 362 + sourcePosition: right + targetPosition: left + type: custom + width: 242 + zIndex: 1002 + - data: + answer: '{{#conversation.answer#}}' + selected: false + title: Answer + type: answer + variables: [] + height: 104 + id: '1759032410331' + position: + x: 1502 + y: 302 + positionAbsolute: + x: 1502 + y: 302 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + code: "\ndef main():\n return {\n \"result\": '',\n }\n" + code_language: python3 + isInIteration: true + isInLoop: false + iteration_id: '1759032379989' + outputs: + result: + children: null + type: string + selected: false + title: Code 2 + type: code + variables: [] + height: 52 + id: '1759032476318' + parentId: '1759032379989' + position: + x: 506 + y: 76 + positionAbsolute: + x: 1140 + y: 378 + sourcePosition: right + targetPosition: left + type: custom + width: 242 + zIndex: 1002 + viewport: + x: 120.39999999999998 + y: 85.20000000000005 + zoom: 0.7 + rag_pipeline_variables: [] diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py new file mode 100644 index 0000000000..68fe82d05e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -0,0 +1,113 @@ +from core.variables.segments import ( + BooleanSegment, + IntegerSegment, + NoneSegment, + StringSegment, +) +from core.workflow.entities.variable_pool import VariablePool + + +class TestVariablePoolGetAndNestedAttribute: + # + # _get_nested_attribute tests + # + def test__get_nested_attribute_existing_key(self): + pool = VariablePool.empty() + obj = {"a": 123} + segment = pool._get_nested_attribute(obj, "a") + assert segment is not None + assert segment.value == 123 + + def test__get_nested_attribute_missing_key(self): + pool = VariablePool.empty() + obj = {"a": 123} + segment = pool._get_nested_attribute(obj, "b") + assert segment is None + + def test__get_nested_attribute_non_dict(self): + pool = VariablePool.empty() + obj = ["not", "a", "dict"] + segment = pool._get_nested_attribute(obj, "a") + assert segment is None + + def test__get_nested_attribute_with_none_value(self): + pool = VariablePool.empty() + obj = {"a": None} + segment = pool._get_nested_attribute(obj, "a") + assert segment is not None + assert isinstance(segment, NoneSegment) + + def 
diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py
new file mode 100644
index 0000000000..68fe82d05e
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py
@@ -0,0 +1,113 @@
+from core.variables.segments import (
+    BooleanSegment,
+    IntegerSegment,
+    NoneSegment,
+    StringSegment,
+)
+from core.workflow.entities.variable_pool import VariablePool
+
+
+class TestVariablePoolGetAndNestedAttribute:
+    #
+    # _get_nested_attribute tests
+    #
+    def test__get_nested_attribute_existing_key(self):
+        pool = VariablePool.empty()
+        obj = {"a": 123}
+        segment = pool._get_nested_attribute(obj, "a")
+        assert segment is not None
+        assert segment.value == 123
+
+    def test__get_nested_attribute_missing_key(self):
+        pool = VariablePool.empty()
+        obj = {"a": 123}
+        segment = pool._get_nested_attribute(obj, "b")
+        assert segment is None
+
+    def test__get_nested_attribute_non_dict(self):
+        pool = VariablePool.empty()
+        obj = ["not", "a", "dict"]
+        segment = pool._get_nested_attribute(obj, "a")
+        assert segment is None
+
+    def test__get_nested_attribute_with_none_value(self):
+        pool = VariablePool.empty()
+        obj = {"a": None}
+        segment = pool._get_nested_attribute(obj, "a")
+        assert segment is not None
+        assert isinstance(segment, NoneSegment)
+
+    def test__get_nested_attribute_with_empty_string(self):
+        pool = VariablePool.empty()
+        obj = {"a": ""}
+        segment = pool._get_nested_attribute(obj, "a")
+        assert segment is not None
+        assert isinstance(segment, StringSegment)
+        assert segment.value == ""
+
+    #
+    # get tests
+    #
+    def test_get_simple_variable(self):
+        pool = VariablePool.empty()
+        pool.add(("node1", "var1"), "value1")
+        segment = pool.get(("node1", "var1"))
+        assert segment is not None
+        assert segment.value == "value1"
+
+    def test_get_missing_variable(self):
+        pool = VariablePool.empty()
+        result = pool.get(("node1", "unknown"))
+        assert result is None
+
+    def test_get_with_too_short_selector(self):
+        pool = VariablePool.empty()
+        result = pool.get(("only_node",))
+        assert result is None
+
+    def test_get_nested_object_attribute(self):
+        pool = VariablePool.empty()
+        obj_value = {"inner": "hello"}
+        pool.add(("node1", "obj"), obj_value)
+
+        # simulate selector with nested attr
+        segment = pool.get(("node1", "obj", "inner"))
+        assert segment is not None
+        assert segment.value == "hello"
+
+    def test_get_nested_object_missing_attribute(self):
+        pool = VariablePool.empty()
+        obj_value = {"inner": "hello"}
+        pool.add(("node1", "obj"), obj_value)
+
+        result = pool.get(("node1", "obj", "not_exist"))
+        assert result is None
+
+    def test_get_nested_object_attribute_with_falsy_values(self):
+        pool = VariablePool.empty()
+        obj_value = {
+            "inner_none": None,
+            "inner_empty": "",
+            "inner_zero": 0,
+            "inner_false": False,
+        }
+        pool.add(("node1", "obj"), obj_value)
+
+        segment_none = pool.get(("node1", "obj", "inner_none"))
+        assert segment_none is not None
+        assert isinstance(segment_none, NoneSegment)
+
+        segment_empty = pool.get(("node1", "obj", "inner_empty"))
+        assert segment_empty is not None
+        assert isinstance(segment_empty, StringSegment)
+        assert segment_empty.value == ""
+
+        segment_zero = pool.get(("node1", "obj", "inner_zero"))
+        assert segment_zero is not None
+        assert isinstance(segment_zero, IntegerSegment)
+        assert segment_zero.value == 0
+
+        segment_false = pool.get(("node1", "obj", "inner_false"))
+        assert segment_false is not None
+        assert isinstance(segment_false, BooleanSegment)
+        assert segment_false.value is False
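The falsy-value cases above guard against a classic lookup bug: truth-testing a retrieved value instead of comparing against `None` silently turns legitimate values such as `0`, `""`, and `False` into "missing". A minimal illustration of the pitfall, using hypothetical helpers over a plain dict rather than the real segment types:

```python
# Hypothetical helpers contrasting the bug with the correct check.
def lookup_wrong(obj: dict, key: str):
    value = obj.get(key)
    return value if value else None  # bug: drops 0, "", and False

def lookup_right(obj: dict, key: str):
    return obj[key] if key in obj else None  # only a missing key maps to None

data = {"zero": 0, "empty": "", "flag": False}
assert lookup_wrong(data, "zero") is None  # 0 mistaken for "missing"
assert lookup_right(data, "zero") == 0     # 0 survives the lookup
```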
+""" + +from .test_mock_config import MockConfigBuilder +from .test_table_runner import TableTestRunner, WorkflowTestCase + + +def test_update_conversation_variable_in_iteration(): + fixture_name = "update-conversation-variable-in-iteration" + user_query = "ensure conversation variable syncs" + + mock_config = ( + MockConfigBuilder() + .with_node_output("1759032363865", {"result": [1]}) + .with_node_output("1759032476318", {"result": ""}) + .build() + ) + + case = WorkflowTestCase( + fixture_path=fixture_name, + use_auto_mock=True, + mock_config=mock_config, + query=user_query, + expected_outputs={"answer": user_query}, + description="Conversation variable updated within iteration should flow to answer output.", + ) + + runner = TableTestRunner() + result = runner.run_test_case(case) + + assert result.success, f"Workflow execution failed: {result.error}" + assert result.actual_outputs is not None + assert result.actual_outputs.get("answer") == user_query diff --git a/api/uv.lock b/api/uv.lock index 8b3564d139..f2abdee78f 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1076,7 +1076,7 @@ wheels = [ [[package]] name = "cos-python-sdk-v5" -version = "1.9.30" +version = "1.9.38" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "crcmod" }, @@ -1085,7 +1085,10 @@ dependencies = [ { name = "six" }, { name = "xmltodict" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/f2/be99b41433b33a76896680920fca621f191875ca410a66778015e47a501b/cos-python-sdk-v5-1.9.30.tar.gz", hash = "sha256:a23fd090211bf90883066d90cd74317860aa67c6d3aa80fe5e44b18c7e9b2a81", size = 108384, upload-time = "2024-06-14T08:02:37.063Z" } +sdist = { url = "https://files.pythonhosted.org/packages/24/3c/d208266fec7cc3221b449e236b87c3fc1999d5ac4379d4578480321cfecc/cos_python_sdk_v5-1.9.38.tar.gz", hash = "sha256:491a8689ae2f1a6f04dacba66a877b2c8d361456f9cfd788ed42170a1cbf7a9f", size = 98092, upload-time = "2025-07-22T07:56:20.34Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/c8/c9c156aa3bc7caba9b4f8a2b6abec3da6263215988f3fec0ea843f137a10/cos_python_sdk_v5-1.9.38-py3-none-any.whl", hash = "sha256:1d3dd3be2bd992b2e9c2dcd018e2596aa38eab022dbc86b4a5d14c8fc88370e6", size = 92601, upload-time = "2025-08-17T05:12:30.867Z" }, +] [[package]] name = "couchbase" @@ -1286,7 +1289,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.9.0" +version = "1.9.1" source = { virtual = "." 
} dependencies = [ { name = "arize-phoenix-otel" }, @@ -1639,7 +1642,7 @@ dev = [ storage = [ { name = "azure-storage-blob", specifier = "==12.13.0" }, { name = "bce-python-sdk", specifier = "~=0.9.23" }, - { name = "cos-python-sdk-v5", specifier = "==1.9.30" }, + { name = "cos-python-sdk-v5", specifier = "==1.9.38" }, { name = "esdk-obs-python", specifier = "==3.24.6.1" }, { name = "google-cloud-storage", specifier = "==2.16.0" }, { name = "opendal", specifier = "~=0.46.0" }, @@ -1661,7 +1664,7 @@ vdb = [ { name = "elasticsearch", specifier = "==8.14.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, { name = "opensearch-py", specifier = "==2.4.0" }, - { name = "oracledb", specifier = "==3.0.0" }, + { name = "oracledb", specifier = "==3.3.0" }, { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" }, { name = "pgvector", specifier = "==0.2.5" }, { name = "pymilvus", specifier = "~=2.5.0" }, @@ -4094,23 +4097,23 @@ numpy = [ [[package]] name = "oracledb" -version = "3.0.0" +version = "3.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bf/39/712f797b75705c21148fa1d98651f63c2e5cc6876e509a0a9e2f5b406572/oracledb-3.0.0.tar.gz", hash = "sha256:64dc86ee5c032febc556798b06e7b000ef6828bb0252084f6addacad3363db85", size = 840431, upload-time = "2025-03-03T19:36:12.223Z" } +sdist = { url = "https://files.pythonhosted.org/packages/51/c9/fae18fa5d803712d188486f8e86ad4f4e00316793ca19745d7c11092c360/oracledb-3.3.0.tar.gz", hash = "sha256:e830d3544a1578296bcaa54c6e8c8ae10a58c7db467c528c4b27adbf9c8b4cb0", size = 811776, upload-time = "2025-07-29T22:34:10.489Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/bf/d872c4b3fc15cd3261fe0ea72b21d181700c92dbc050160e161654987062/oracledb-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52daa9141c63dfa75c07d445e9bb7f69f43bfb3c5a173ecc48c798fe50288d26", size = 4312963, upload-time = "2025-03-03T19:36:32.576Z" }, - { url = "https://files.pythonhosted.org/packages/b1/ea/01ee29e76a610a53bb34fdc1030f04b7669c3f80b25f661e07850fc6160e/oracledb-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af98941789df4c6aaaf4338f5b5f6b7f2c8c3fe6f8d6a9382f177f350868747a", size = 2661536, upload-time = "2025-03-03T19:36:34.904Z" }, - { url = "https://files.pythonhosted.org/packages/3d/8e/ad380e34a46819224423b4773e58c350bc6269643c8969604097ced8c3bc/oracledb-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9812bb48865aaec35d73af54cd1746679f2a8a13cbd1412ab371aba2e39b3943", size = 2867461, upload-time = "2025-03-03T19:36:36.508Z" }, - { url = "https://files.pythonhosted.org/packages/96/09/ecc4384a27fd6e1e4de824ae9c160e4ad3aaebdaade5b4bdcf56a4d1ff63/oracledb-3.0.0-cp311-cp311-win32.whl", hash = "sha256:6c27fe0de64f2652e949eb05b3baa94df9b981a4a45fa7f8a991e1afb450c8e2", size = 1752046, upload-time = "2025-03-03T19:36:38.313Z" }, - { url = "https://files.pythonhosted.org/packages/62/e8/f34bde24050c6e55eeba46b23b2291f2dd7fd272fa8b322dcbe71be55778/oracledb-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:f922709672002f0b40997456f03a95f03e5712a86c61159951c5ce09334325e0", size = 2101210, upload-time = "2025-03-03T19:36:40.669Z" }, - { url = "https://files.pythonhosted.org/packages/6f/fc/24590c3a3d41e58494bd3c3b447a62835138e5f9b243d9f8da0cfb5da8dc/oracledb-3.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = 
"sha256:acd0e747227dea01bebe627b07e958bf36588a337539f24db629dc3431d3f7eb", size = 4351993, upload-time = "2025-03-03T19:36:42.577Z" }, - { url = "https://files.pythonhosted.org/packages/b7/b6/1f3b0b7bb94d53e8857d77b2e8dbdf6da091dd7e377523e24b79dac4fd71/oracledb-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8b402f77c22af031cd0051aea2472ecd0635c1b452998f511aa08b7350c90a4", size = 2532640, upload-time = "2025-03-03T19:36:45.066Z" }, - { url = "https://files.pythonhosted.org/packages/72/1a/1815f6c086ab49c00921cf155ff5eede5267fb29fcec37cb246339a5ce4d/oracledb-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:378a27782e9a37918bd07a5a1427a77cb6f777d0a5a8eac9c070d786f50120ef", size = 2765949, upload-time = "2025-03-03T19:36:47.47Z" }, - { url = "https://files.pythonhosted.org/packages/33/8d/208900f8d372909792ee70b2daad3f7361181e55f2217c45ed9dff658b54/oracledb-3.0.0-cp312-cp312-win32.whl", hash = "sha256:54a28c2cb08316a527cd1467740a63771cc1c1164697c932aa834c0967dc4efc", size = 1709373, upload-time = "2025-03-03T19:36:49.67Z" }, - { url = "https://files.pythonhosted.org/packages/0c/5e/c21754f19c896102793c3afec2277e2180aa7d505e4d7fcca24b52d14e4f/oracledb-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8289bad6d103ce42b140e40576cf0c81633e344d56e2d738b539341eacf65624", size = 2056452, upload-time = "2025-03-03T19:36:51.363Z" }, + { url = "https://files.pythonhosted.org/packages/3f/35/95d9a502fdc48ce1ef3a513ebd027488353441e15aa0448619abb3d09d32/oracledb-3.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d9adb74f837838e21898d938e3a725cf73099c65f98b0b34d77146b453e945e0", size = 3963945, upload-time = "2025-07-29T22:34:28.633Z" }, + { url = "https://files.pythonhosted.org/packages/16/a7/8f1ef447d995bb51d9fdc36356697afeceb603932f16410c12d52b2df1a4/oracledb-3.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b063d1007882570f170ebde0f364e78d4a70c8f015735cc900663278b9ceef7", size = 2449385, upload-time = "2025-07-29T22:34:30.592Z" }, + { url = "https://files.pythonhosted.org/packages/b3/fa/6a78480450bc7d256808d0f38ade3385735fb5a90dab662167b4257dcf94/oracledb-3.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:187728f0a2d161676b8c581a9d8f15d9631a8fea1e628f6d0e9fa2f01280cd22", size = 2634943, upload-time = "2025-07-29T22:34:33.142Z" }, + { url = "https://files.pythonhosted.org/packages/5b/90/ea32b569a45fb99fac30b96f1ac0fb38b029eeebb78357bc6db4be9dde41/oracledb-3.3.0-cp311-cp311-win32.whl", hash = "sha256:920f14314f3402c5ab98f2efc5932e0547e9c0a4ca9338641357f73844e3e2b1", size = 1483549, upload-time = "2025-07-29T22:34:35.015Z" }, + { url = "https://files.pythonhosted.org/packages/81/55/ae60f72836eb8531b630299f9ed68df3fe7868c6da16f820a108155a21f9/oracledb-3.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:825edb97976468db1c7e52c78ba38d75ce7e2b71a2e88f8629bcf02be8e68a8a", size = 1834737, upload-time = "2025-07-29T22:34:36.824Z" }, + { url = "https://files.pythonhosted.org/packages/08/a8/f6b7809d70e98e113786d5a6f1294da81c046d2fa901ad656669fc5d7fae/oracledb-3.3.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9d25e37d640872731ac9b73f83cbc5fc4743cd744766bdb250488caf0d7696a8", size = 3943512, upload-time = "2025-07-29T22:34:39.237Z" }, + { url = 
"https://files.pythonhosted.org/packages/df/b9/8145ad8991f4864d3de4a911d439e5bc6cdbf14af448f3ab1e846a54210c/oracledb-3.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0bf7cdc2b668f939aa364f552861bc7a149d7cd3f3794730d43ef07613b2bf9", size = 2276258, upload-time = "2025-07-29T22:34:41.547Z" }, + { url = "https://files.pythonhosted.org/packages/56/bf/f65635ad5df17d6e4a2083182750bb136ac663ff0e9996ce59d77d200f60/oracledb-3.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2fe20540fde64a6987046807ea47af93be918fd70b9766b3eb803c01e6d4202e", size = 2458811, upload-time = "2025-07-29T22:34:44.648Z" }, + { url = "https://files.pythonhosted.org/packages/7d/30/e0c130b6278c10b0e6cd77a3a1a29a785c083c549676cf701c5d180b8e63/oracledb-3.3.0-cp312-cp312-win32.whl", hash = "sha256:db080be9345cbf9506ffdaea3c13d5314605355e76d186ec4edfa49960ffb813", size = 1445525, upload-time = "2025-07-29T22:34:46.603Z" }, + { url = "https://files.pythonhosted.org/packages/1a/5c/7254f5e1a33a5d6b8bf6813d4f4fdcf5c4166ec8a7af932d987879d5595c/oracledb-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:be81e3afe79f6c8ece79a86d6067ad1572d2992ce1c590a086f3755a09535eb4", size = 1789976, upload-time = "2025-07-29T22:34:48.5Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index ad99c3d448..dce5d5cc44 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -45,7 +45,7 @@ APP_WEB_URL= # Recommendation: use a dedicated domain (e.g., https://upload.example.com). # Alternatively, use http://:5001 or http://api:5001, # ensuring port 5001 is externally accessible (see docker-compose.yaml). -FILES_URL=http://api:5001 +FILES_URL= # INTERNAL_FILES_URL is used for plugin daemon communication within Docker network. # Set this to the internal Docker service URL for proper plugin file access. diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index e4071df821..ca2928679e 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.9.0 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -31,7 +31,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.9.0 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -58,7 +58,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.9.0 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -76,7 +76,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:1.9.0 + image: langgenius/dify-web:1.9.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 53f17f0389..d3f382b67f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -10,7 +10,7 @@ x-shared-env: &shared-api-worker-env SERVICE_API_URL: ${SERVICE_API_URL:-} APP_API_URL: ${APP_API_URL:-} APP_WEB_URL: ${APP_WEB_URL:-} - FILES_URL: ${FILES_URL:-http://api:5001} + FILES_URL: ${FILES_URL:-} INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-} LANG: ${LANG:-en_US.UTF-8} LC_ALL: ${LC_ALL:-en_US.UTF-8} @@ -604,7 +604,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.9.0 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -633,7 +633,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.9.0 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -660,7 +660,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.9.0 + image: langgenius/dify-api:1.9.1 restart: always environment: # Use the shared environment variables. @@ -678,7 +678,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.9.0 + image: langgenius/dify-web:1.9.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py index e866472f45..e252bc0472 100644 --- a/sdks/python-client/dify_client/__init__.py +++ b/sdks/python-client/dify_client/__init__.py @@ -4,6 +4,7 @@ from dify_client.client import ( DifyClient, KnowledgeBaseClient, WorkflowClient, + WorkspaceClient, ) __all__ = [ @@ -12,4 +13,5 @@ __all__ = [ "DifyClient", "KnowledgeBaseClient", "WorkflowClient", + "WorkspaceClient", ] diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index 201391eae9..fb42e3773d 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -1,5 +1,6 @@ import json -from typing import Literal +from typing import Literal, Union, Dict, List, Any, Optional, IO + import requests @@ -49,6 +50,18 @@ class DifyClient: params = {"user": user} return self._send_request("GET", "/meta", params=params) + def get_app_info(self): + """Get basic application information including name, description, tags, and mode.""" + return self._send_request("GET", "/info") + + def get_app_site_info(self): + """Get application site information.""" + return self._send_request("GET", "/site") + + def get_file_preview(self, file_id: str): + """Get file preview by file ID.""" + return self._send_request("GET", f"/files/{file_id}/preview") + class CompletionClient(DifyClient): def create_completion_message( @@ -144,6 +157,51 @@ class ChatClient(DifyClient): files = {"file": audio_file} return self._send_request_with_files("POST", "/audio-to-text", data, files) + # Annotation APIs + def annotation_reply_action( + self, + action: Literal["enable", "disable"], + score_threshold: float, + embedding_provider_name: str, + embedding_model_name: str, + ): + """Enable or disable annotation reply feature.""" + # Backend API requires these fields to be non-None values + if score_threshold is None or embedding_provider_name is None or embedding_model_name is 
None:
+            raise ValueError("score_threshold, embedding_provider_name, and embedding_model_name cannot be None")
+
+        data = {
+            "score_threshold": score_threshold,
+            "embedding_provider_name": embedding_provider_name,
+            "embedding_model_name": embedding_model_name,
+        }
+        return self._send_request("POST", f"/apps/annotation-reply/{action}", json=data)
+
+    def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str):
+        """Get the status of an annotation reply action job."""
+        return self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}")
+
+    def list_annotations(self, page: int = 1, limit: int = 20, keyword: str = ""):
+        """List annotations for the application."""
+        params = {"page": page, "limit": limit}
+        if keyword:
+            params["keyword"] = keyword
+        return self._send_request("GET", "/apps/annotations", params=params)
+
+    def create_annotation(self, question: str, answer: str):
+        """Create a new annotation."""
+        data = {"question": question, "answer": answer}
+        return self._send_request("POST", "/apps/annotations", json=data)
+
+    def update_annotation(self, annotation_id: str, question: str, answer: str):
+        """Update an existing annotation."""
+        data = {"question": question, "answer": answer}
+        return self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data)
+
+    def delete_annotation(self, annotation_id: str):
+        """Delete an annotation."""
+        return self._send_request("DELETE", f"/apps/annotations/{annotation_id}")
+
 
 class WorkflowClient(DifyClient):
     def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"):
@@ -157,6 +215,55 @@ class WorkflowClient(DifyClient):
     def get_result(self, workflow_run_id):
         return self._send_request("GET", f"/workflows/run/{workflow_run_id}")
 
+    def get_workflow_logs(
+        self,
+        keyword: Optional[str] = None,
+        status: Optional[Literal["succeeded", "failed", "stopped"]] = None,
+        page: int = 1,
+        limit: int = 20,
+        created_at__before: Optional[str] = None,
+        created_at__after: Optional[str] = None,
+        created_by_end_user_session_id: Optional[str] = None,
+        created_by_account: Optional[str] = None,
+    ):
+        """Get workflow execution logs with optional filtering."""
+        params = {"page": page, "limit": limit}
+        if keyword:
+            params["keyword"] = keyword
+        if status:
+            params["status"] = status
+        if created_at__before:
+            params["created_at__before"] = created_at__before
+        if created_at__after:
+            params["created_at__after"] = created_at__after
+        if created_by_end_user_session_id:
+            params["created_by_end_user_session_id"] = created_by_end_user_session_id
+        if created_by_account:
+            params["created_by_account"] = created_by_account
+        return self._send_request("GET", "/workflows/logs", params=params)
+
+    def run_specific_workflow(
+        self,
+        workflow_id: str,
+        inputs: dict,
+        response_mode: Literal["blocking", "streaming"] = "streaming",
+        user: str = "abc-123",
+    ):
+        """Run a specific workflow by workflow ID."""
+        data = {"inputs": inputs, "response_mode": response_mode, "user": user}
+        return self._send_request(
+            "POST", f"/workflows/{workflow_id}/run", data, stream=True if response_mode == "streaming" else False
+        )
+
+
+class WorkspaceClient(DifyClient):
+    """Client for workspace-related operations."""
+
+    def get_available_models(self, model_type: str):
+        """Get available models by model type."""
+        url = f"/workspaces/current/models/model-types/{model_type}"
+        return self._send_request("GET", url)
+
 
 class KnowledgeBaseClient(DifyClient):
     def __init__(
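The hunks above give the SDK first-class access to annotations, workflow logs, and workspace models. A rough usage sketch (hypothetical API key and base URL; `_send_request` is assumed to return the underlying `requests` response, so `.json()` applies):

```python
from dify_client import WorkflowClient, WorkspaceClient

API_KEY = "your-api-key"  # hypothetical placeholder
BASE_URL = "https://api.dify.ai/v1"

# Page through recent successful workflow runs.
workflow = WorkflowClient(API_KEY, BASE_URL)
logs = workflow.get_workflow_logs(status="succeeded", page=1, limit=10)
print(logs.json())

# List the LLM models configured for the current workspace.
workspace = WorkspaceClient(API_KEY, BASE_URL)
models = workspace.get_available_models("llm")
print(models.json())
```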
@@ -443,3 +550,117 @@ class KnowledgeBaseClient(DifyClient):
         data = {"segment": segment_data}
         url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
         return self._send_request("POST", url, json=data, **kwargs)
+
+    # Advanced Knowledge Base APIs
+    def hit_testing(
+        self,
+        query: str,
+        retrieval_model: Optional[Dict[str, Any]] = None,
+        external_retrieval_model: Optional[Dict[str, Any]] = None,
+    ):
+        """Perform hit testing on the dataset."""
+        data = {"query": query}
+        if retrieval_model:
+            data["retrieval_model"] = retrieval_model
+        if external_retrieval_model:
+            data["external_retrieval_model"] = external_retrieval_model
+        url = f"/datasets/{self._get_dataset_id()}/hit-testing"
+        return self._send_request("POST", url, json=data)
+
+    def get_dataset_metadata(self):
+        """Get dataset metadata."""
+        url = f"/datasets/{self._get_dataset_id()}/metadata"
+        return self._send_request("GET", url)
+
+    def create_dataset_metadata(self, metadata_data: Dict[str, Any]):
+        """Create dataset metadata."""
+        url = f"/datasets/{self._get_dataset_id()}/metadata"
+        return self._send_request("POST", url, json=metadata_data)
+
+    def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]):
+        """Update dataset metadata."""
+        url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}"
+        return self._send_request("PATCH", url, json=metadata_data)
+
+    def get_built_in_metadata(self):
+        """Get built-in metadata."""
+        url = f"/datasets/{self._get_dataset_id()}/metadata/built-in"
+        return self._send_request("GET", url)
+
+    def manage_built_in_metadata(self, action: str, metadata_data: Optional[Dict[str, Any]] = None):
+        """Manage built-in metadata with specified action."""
+        data = metadata_data or {}
+        url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}"
+        return self._send_request("POST", url, json=data)
+
+    def update_documents_metadata(self, operation_data: List[Dict[str, Any]]):
+        """Update metadata for multiple documents."""
+        url = f"/datasets/{self._get_dataset_id()}/documents/metadata"
+        data = {"operation_data": operation_data}
+        return self._send_request("POST", url, json=data)
+
+    # Dataset Tags APIs
+    def list_dataset_tags(self):
+        """List all dataset tags."""
+        return self._send_request("GET", "/datasets/tags")
+
+    def bind_dataset_tags(self, tag_ids: List[str]):
+        """Bind tags to dataset."""
+        data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()}
+        return self._send_request("POST", "/datasets/tags/binding", json=data)
+
+    def unbind_dataset_tag(self, tag_id: str):
+        """Unbind a single tag from dataset."""
+        data = {"tag_id": tag_id, "target_id": self._get_dataset_id()}
+        return self._send_request("POST", "/datasets/tags/unbinding", json=data)
+
+    def get_dataset_tags(self):
+        """Get tags for current dataset."""
+        url = f"/datasets/{self._get_dataset_id()}/tags"
+        return self._send_request("GET", url)
+
+    # RAG Pipeline APIs
+    def get_datasource_plugins(self, is_published: bool = True):
+        """Get datasource plugins for RAG pipeline."""
+        params = {"is_published": is_published}
+        url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins"
+        return self._send_request("GET", url, params=params)
+
+    def run_datasource_node(
+        self,
+        node_id: str,
+        inputs: Dict[str, Any],
+        datasource_type: str,
+        is_published: bool = True,
+        credential_id: Optional[str] = None,
+    ):
+        """Run a datasource node in RAG pipeline."""
+        data = {"inputs": inputs, "datasource_type": datasource_type, "is_published": is_published}
+        if credential_id:
+            data["credential_id"] = credential_id
+        url = 
f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" + return self._send_request("POST", url, json=data, stream=True) + + def run_rag_pipeline( + self, + inputs: Dict[str, Any], + datasource_type: str, + datasource_info_list: List[Dict[str, Any]], + start_node_id: str, + is_published: bool = True, + response_mode: Literal["streaming", "blocking"] = "blocking", + ): + """Run RAG pipeline.""" + data = { + "inputs": inputs, + "datasource_type": datasource_type, + "datasource_info_list": datasource_info_list, + "start_node_id": start_node_id, + "is_published": is_published, + "response_mode": response_mode, + } + url = f"/datasets/{self._get_dataset_id()}/pipeline/run" + return self._send_request("POST", url, json=data, stream=response_mode == "streaming") + + def upload_pipeline_file(self, file_path: str): + """Upload file for RAG pipeline.""" + with open(file_path, "rb") as f: + files = {"file": f} + return self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) diff --git a/sdks/python-client/tests/test_new_apis.py b/sdks/python-client/tests/test_new_apis.py new file mode 100644 index 0000000000..09c62dfda7 --- /dev/null +++ b/sdks/python-client/tests/test_new_apis.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +""" +Test suite for the new Service API functionality in the Python SDK. + +This test validates the implementation of the missing Service API endpoints +that were added to the Python SDK to achieve complete coverage. +""" + +import unittest +from unittest.mock import Mock, patch, MagicMock +import json + +from dify_client import ( + DifyClient, + ChatClient, + WorkflowClient, + KnowledgeBaseClient, + WorkspaceClient, +) + + +class TestNewServiceAPIs(unittest.TestCase): + """Test cases for new Service API implementations.""" + + def setUp(self): + """Set up test fixtures.""" + self.api_key = "test-api-key" + self.base_url = "https://api.dify.ai/v1" + + @patch("dify_client.client.requests.request") + def test_app_info_apis(self, mock_request): + """Test application info APIs.""" + mock_response = Mock() + mock_response.json.return_value = { + "name": "Test App", + "description": "Test Description", + "tags": ["test", "api"], + "mode": "chat", + "author_name": "Test Author", + } + mock_request.return_value = mock_response + + client = DifyClient(self.api_key, self.base_url) + + # Test get_app_info + result = client.get_app_info() + mock_request.assert_called_with( + "GET", + f"{self.base_url}/info", + json=None, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test get_app_site_info + client.get_app_site_info() + mock_request.assert_called_with( + "GET", + f"{self.base_url}/site", + json=None, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test get_file_preview + file_id = "test-file-id" + client.get_file_preview(file_id) + mock_request.assert_called_with( + "GET", + f"{self.base_url}/files/{file_id}/preview", + json=None, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + @patch("dify_client.client.requests.request") + def test_annotation_apis(self, mock_request): + """Test annotation APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + client = ChatClient(self.api_key, 
self.base_url) + + # Test annotation_reply_action - enable + client.annotation_reply_action( + action="enable", + score_threshold=0.8, + embedding_provider_name="openai", + embedding_model_name="text-embedding-ada-002", + ) + mock_request.assert_called_with( + "POST", + f"{self.base_url}/apps/annotation-reply/enable", + json={ + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + }, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test annotation_reply_action - disable (now requires same fields as enable) + client.annotation_reply_action( + action="disable", + score_threshold=0.5, + embedding_provider_name="openai", + embedding_model_name="text-embedding-ada-002", + ) + + # Test annotation_reply_action with score_threshold=0 (edge case) + client.annotation_reply_action( + action="enable", + score_threshold=0.0, # This should work and not raise ValueError + embedding_provider_name="openai", + embedding_model_name="text-embedding-ada-002", + ) + + # Test get_annotation_reply_status + client.get_annotation_reply_status("enable", "job-123") + + # Test list_annotations + client.list_annotations(page=1, limit=20, keyword="test") + + # Test create_annotation + client.create_annotation("Test question?", "Test answer.") + + # Test update_annotation + client.update_annotation("annotation-123", "Updated question?", "Updated answer.") + + # Test delete_annotation + client.delete_annotation("annotation-123") + + # Verify all calls were made (8 calls: enable + disable + enable with 0.0 + 5 other operations) + self.assertEqual(mock_request.call_count, 8) + + @patch("dify_client.client.requests.request") + def test_knowledge_base_advanced_apis(self, mock_request): + """Test advanced knowledge base APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + dataset_id = "test-dataset-id" + client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id) + + # Test hit_testing + client.hit_testing("test query", {"type": "vector"}) + mock_request.assert_called_with( + "POST", + f"{self.base_url}/datasets/{dataset_id}/hit-testing", + json={"query": "test query", "retrieval_model": {"type": "vector"}}, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test metadata operations + client.get_dataset_metadata() + client.create_dataset_metadata({"key": "value"}) + client.update_dataset_metadata("meta-123", {"key": "new_value"}) + client.get_built_in_metadata() + client.manage_built_in_metadata("enable", {"type": "built_in"}) + client.update_documents_metadata([{"document_id": "doc1", "metadata": {"key": "value"}}]) + + # Test tag operations + client.list_dataset_tags() + client.bind_dataset_tags(["tag1", "tag2"]) + client.unbind_dataset_tag("tag1") + client.get_dataset_tags() + + # Verify multiple calls were made + self.assertGreater(mock_request.call_count, 5) + + @patch("dify_client.client.requests.request") + def test_rag_pipeline_apis(self, mock_request): + """Test RAG pipeline APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + dataset_id = "test-dataset-id" + client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id) + + # Test get_datasource_plugins + 
client.get_datasource_plugins(is_published=True) + mock_request.assert_called_with( + "GET", + f"{self.base_url}/datasets/{dataset_id}/pipeline/datasource-plugins", + json=None, + params={"is_published": True}, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test run_datasource_node + client.run_datasource_node( + node_id="node-123", + inputs={"param": "value"}, + datasource_type="online_document", + is_published=True, + credential_id="cred-123", + ) + + # Test run_rag_pipeline with blocking mode + client.run_rag_pipeline( + inputs={"query": "test"}, + datasource_type="online_document", + datasource_info_list=[{"id": "ds1"}], + start_node_id="start-node", + is_published=True, + response_mode="blocking", + ) + + # Test run_rag_pipeline with streaming mode + client.run_rag_pipeline( + inputs={"query": "test"}, + datasource_type="online_document", + datasource_info_list=[{"id": "ds1"}], + start_node_id="start-node", + is_published=True, + response_mode="streaming", + ) + + self.assertEqual(mock_request.call_count, 4) + + @patch("dify_client.client.requests.request") + def test_workspace_apis(self, mock_request): + """Test workspace APIs.""" + mock_response = Mock() + mock_response.json.return_value = { + "data": [{"name": "gpt-3.5-turbo", "type": "llm"}, {"name": "gpt-4", "type": "llm"}] + } + mock_request.return_value = mock_response + + client = WorkspaceClient(self.api_key, self.base_url) + + # Test get_available_models + result = client.get_available_models("llm") + mock_request.assert_called_with( + "GET", + f"{self.base_url}/workspaces/current/models/model-types/llm", + json=None, + params=None, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + @patch("dify_client.client.requests.request") + def test_workflow_advanced_apis(self, mock_request): + """Test advanced workflow APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + client = WorkflowClient(self.api_key, self.base_url) + + # Test get_workflow_logs + client.get_workflow_logs(keyword="test", status="succeeded", page=1, limit=20) + mock_request.assert_called_with( + "GET", + f"{self.base_url}/workflows/logs", + json=None, + params={"page": 1, "limit": 20, "keyword": "test", "status": "succeeded"}, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + stream=False, + ) + + # Test get_workflow_logs with additional filters + client.get_workflow_logs( + keyword="test", + status="succeeded", + page=1, + limit=20, + created_at__before="2024-01-01", + created_at__after="2023-01-01", + created_by_account="user123", + ) + + # Test run_specific_workflow + client.run_specific_workflow( + workflow_id="workflow-123", inputs={"param": "value"}, response_mode="streaming", user="user-123" + ) + + self.assertEqual(mock_request.call_count, 3) + + def test_error_handling(self): + """Test error handling for required parameters.""" + client = ChatClient(self.api_key, self.base_url) + + # Test annotation_reply_action with missing required parameters would be a TypeError now + # since parameters are required in method signature + with self.assertRaises(TypeError): + client.annotation_reply_action("enable") + + # Test annotation_reply_action with explicit None values should raise ValueError + with self.assertRaises(ValueError) as context: + 
client.annotation_reply_action("enable", None, "provider", "model") + + self.assertIn("cannot be None", str(context.exception)) + + # Test KnowledgeBaseClient without dataset_id + kb_client = KnowledgeBaseClient(self.api_key, self.base_url) + with self.assertRaises(ValueError) as context: + kb_client.hit_testing("test query") + + self.assertIn("dataset_id is not set", str(context.exception)) + + @patch("dify_client.client.open") + @patch("dify_client.client.requests.request") + def test_file_upload_apis(self, mock_request, mock_open): + """Test file upload APIs.""" + mock_response = Mock() + mock_response.json.return_value = {"result": "success"} + mock_request.return_value = mock_response + + mock_file = MagicMock() + mock_open.return_value.__enter__.return_value = mock_file + + dataset_id = "test-dataset-id" + client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id) + + # Test upload_pipeline_file + client.upload_pipeline_file("/path/to/test.pdf") + + mock_open.assert_called_with("/path/to/test.pdf", "rb") + mock_request.assert_called_once() + + def test_comprehensive_coverage(self): + """Test that all previously missing APIs are now implemented.""" + + # Test DifyClient methods + dify_methods = ["get_app_info", "get_app_site_info", "get_file_preview"] + client = DifyClient(self.api_key) + for method in dify_methods: + self.assertTrue(hasattr(client, method), f"DifyClient missing method: {method}") + + # Test ChatClient annotation methods + chat_methods = [ + "annotation_reply_action", + "get_annotation_reply_status", + "list_annotations", + "create_annotation", + "update_annotation", + "delete_annotation", + ] + chat_client = ChatClient(self.api_key) + for method in chat_methods: + self.assertTrue(hasattr(chat_client, method), f"ChatClient missing method: {method}") + + # Test WorkflowClient advanced methods + workflow_methods = ["get_workflow_logs", "run_specific_workflow"] + workflow_client = WorkflowClient(self.api_key) + for method in workflow_methods: + self.assertTrue(hasattr(workflow_client, method), f"WorkflowClient missing method: {method}") + + # Test KnowledgeBaseClient advanced methods + kb_methods = [ + "hit_testing", + "get_dataset_metadata", + "create_dataset_metadata", + "update_dataset_metadata", + "get_built_in_metadata", + "manage_built_in_metadata", + "update_documents_metadata", + "list_dataset_tags", + "bind_dataset_tags", + "unbind_dataset_tag", + "get_dataset_tags", + "get_datasource_plugins", + "run_datasource_node", + "run_rag_pipeline", + "upload_pipeline_file", + ] + kb_client = KnowledgeBaseClient(self.api_key) + for method in kb_methods: + self.assertTrue(hasattr(kb_client, method), f"KnowledgeBaseClient missing method: {method}") + + # Test WorkspaceClient methods + workspace_methods = ["get_available_models"] + workspace_client = WorkspaceClient(self.api_key) + for method in workspace_methods: + self.assertTrue(hasattr(workspace_client, method), f"WorkspaceClient missing method: {method}") + + +if __name__ == "__main__": + unittest.main() diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index afa8732701..264b1ac727 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -38,7 +38,7 @@ const Annotation: FC = (props) => { const [isShowEdit, setIsShowEdit] = useState(false) const [annotationConfig, setAnnotationConfig] = useState(null) const [isChatApp] = useState(appDetail.mode !== 'completion') - const [controlRefreshSwitch, 
setControlRefreshSwitch] = useState(Date.now()) + const [controlRefreshSwitch, setControlRefreshSwitch] = useState(() => Date.now()) const { plan, enableBilling } = useProviderContext() const isAnnotationFull = enableBilling && plan.usage.annotatedResponse >= plan.total.annotatedResponse const [isShowAnnotationFullModal, setIsShowAnnotationFullModal] = useState(false) @@ -48,7 +48,7 @@ const Annotation: FC = (props) => { const [list, setList] = useState([]) const [total, setTotal] = useState(0) const [isLoading, setIsLoading] = useState(false) - const [controlUpdateList, setControlUpdateList] = useState(Date.now()) + const [controlUpdateList, setControlUpdateList] = useState(() => Date.now()) const [currItem, setCurrItem] = useState(null) const [isShowViewModal, setIsShowViewModal] = useState(false) const [selectedIds, setSelectedIds] = useState([]) diff --git a/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx b/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx index 1457a298f2..9e10db93ae 100644 --- a/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx +++ b/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx @@ -25,7 +25,7 @@ const PromptEditorHeightResizeWrap: FC = ({ }) => { const [clientY, setClientY] = useState(0) const [isResizing, setIsResizing] = useState(false) - const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect) + const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(() => getComputedStyle(document.body).userSelect) const [oldHeight, setOldHeight] = useState(height) const handleStartResize = useCallback((e: React.MouseEvent) => { diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index cecc076fe7..b0f0ea8779 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -53,7 +53,7 @@ const ConfigModal: FC = ({ }) => { const { modelConfig } = useContext(ConfigContext) const { t } = useTranslation() - const [tempPayload, setTempPayload] = useState(payload || getNewVarInWorkflow('') as any) + const [tempPayload, setTempPayload] = useState(() => payload || getNewVarInWorkflow('') as any) const { type, label, variable, options, max_length } = tempPayload const modalRef = useRef(null) const appDetail = useAppStore(state => state.appDetail) diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index 6165cfdeec..65ef74bc27 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -65,13 +65,40 @@ const DatasetConfig: FC = () => { const onRemove = (id: string) => { const filteredDataSets = dataSet.filter(item => item.id !== id) setDataSet(filteredDataSets) - const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, { + const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs + const { + top_k, + score_threshold, + reranking_model, + reranking_mode, + weights, + reranking_enable, + } = restConfigs + const oldRetrievalConfig = { + top_k, + score_threshold, + reranking_model: (reranking_model.reranking_provider_name && 
reranking_model.reranking_model_name) ? { + provider: reranking_model.reranking_provider_name, + model: reranking_model.reranking_model_name, + } : undefined, + reranking_mode, + weights, + reranking_enable, + } + const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, filteredDataSets, dataSet, { provider: currentRerankProvider?.provider, model: currentRerankModel?.model, }) setDatasetConfigs({ - ...(datasetConfigs as any), + ...datasetConfigsRef.current, ...retrievalConfig, + reranking_model: { + reranking_provider_name: retrievalConfig?.reranking_model?.provider || '', + reranking_model_name: retrievalConfig?.reranking_model?.model || '', + }, + retrieval_model, + score_threshold_enabled, + datasets, }) const { allExternal, diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index cb61b927bc..1558d32fc6 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -30,11 +30,11 @@ import { noop } from 'lodash-es' type Props = { datasetConfigs: DatasetConfigs onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void + selectedDatasets?: DataSet[] isInWorkflow?: boolean singleRetrievalModelConfig?: ModelConfig onSingleRetrievalModelChange?: (config: ModelConfig) => void onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void - selectedDatasets?: DataSet[] } const ConfigContent: FC = ({ @@ -61,22 +61,28 @@ const ConfigContent: FC = ({ const { modelList: rerankModelList, + currentModel: validDefaultRerankModel, + currentProvider: validDefaultRerankProvider, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + /** + * If reranking model is set and is valid, use the reranking model + * Otherwise, check if the default reranking model is valid + */ const { currentModel: currentRerankModel, } = useCurrentProviderAndModel( rerankModelList, { - provider: datasetConfigs.reranking_model?.reranking_provider_name, - model: datasetConfigs.reranking_model?.reranking_model_name, + provider: datasetConfigs.reranking_model?.reranking_provider_name || validDefaultRerankProvider?.provider || '', + model: datasetConfigs.reranking_model?.reranking_model_name || validDefaultRerankModel?.model || '', }, ) const rerankModel = useMemo(() => { return { - provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '', - model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '', + provider_name: datasetConfigs.reranking_model?.reranking_provider_name ?? '', + model_name: datasetConfigs.reranking_model?.reranking_model_name ?? 
'', } }, [datasetConfigs.reranking_model]) @@ -135,7 +141,7 @@ const ConfigContent: FC = ({ }) } - const model = singleRetrievalConfig + const model = singleRetrievalConfig // Legacy code, for compatibility, have to keep it const rerankingModeOptions = [ { @@ -158,7 +164,7 @@ const ConfigContent: FC = ({ const canManuallyToggleRerank = useMemo(() => { return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) - || selectedDatasetsMode.allExternal + || selectedDatasetsMode.allExternal }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) const showRerankModel = useMemo(() => { @@ -168,7 +174,7 @@ const ConfigContent: FC = ({ return datasetConfigs.reranking_enable }, [datasetConfigs.reranking_enable, canManuallyToggleRerank]) - const handleDisabledSwitchClick = useCallback((enable: boolean) => { + const handleManuallyToggleRerank = useCallback((enable: boolean) => { if (!currentRerankModel && enable) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) onChange({ @@ -255,12 +261,11 @@ const ConfigContent: FC = ({
{ - selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( + canManuallyToggleRerank && ( ) } diff --git a/web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts b/web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts index 193ac87dd0..92958cc96d 100644 --- a/web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts +++ b/web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts @@ -35,8 +35,8 @@ const useAdvancedPromptConfig = ({ setStop, }: Param) => { const isAdvancedPrompt = promptMode === PromptMode.advanced - const [chatPromptConfig, setChatPromptConfig] = useState(clone(DEFAULT_CHAT_PROMPT_CONFIG)) - const [completionPromptConfig, setCompletionPromptConfig] = useState(clone(DEFAULT_COMPLETION_PROMPT_CONFIG)) + const [chatPromptConfig, setChatPromptConfig] = useState(() => clone(DEFAULT_CHAT_PROMPT_CONFIG)) + const [completionPromptConfig, setCompletionPromptConfig] = useState(() => clone(DEFAULT_COMPLETION_PROMPT_CONFIG)) const currentAdvancedPrompt = (() => { if (!isAdvancedPrompt) diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 091900642a..f1f81ebf97 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -284,18 +284,28 @@ const Configuration: FC = () => { setRerankSettingModalOpen(true) const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs + const { + top_k, + score_threshold, + reranking_model, + reranking_mode, + weights, + reranking_enable, + } = restConfigs - const retrievalConfig = getMultipleRetrievalConfig({ - top_k: restConfigs.top_k, - score_threshold: restConfigs.score_threshold, - reranking_model: restConfigs.reranking_model && { - provider: restConfigs.reranking_model.reranking_provider_name, - model: restConfigs.reranking_model.reranking_model_name, - }, - reranking_mode: restConfigs.reranking_mode, - weights: restConfigs.weights, - reranking_enable: restConfigs.reranking_enable, - }, newDatasets, dataSets, { + const oldRetrievalConfig = { + top_k, + score_threshold, + reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? 
{ + provider: reranking_model.reranking_provider_name, + model: reranking_model.reranking_model_name, + } : undefined, + reranking_mode, + weights, + reranking_enable, + } + + const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, newDatasets, dataSets, { provider: currentRerankProvider?.provider, model: currentRerankModel?.model, }) diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index bee37cf2cd..a362f4dc99 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -160,8 +160,13 @@ const Chat: FC = ({ }) useEffect(() => { - window.addEventListener('resize', debounce(handleWindowResize)) - return () => window.removeEventListener('resize', handleWindowResize) + const debouncedHandler = debounce(handleWindowResize, 200) + window.addEventListener('resize', debouncedHandler) + + return () => { + window.removeEventListener('resize', debouncedHandler) + debouncedHandler.cancel() + } }, [handleWindowResize]) useEffect(() => { diff --git a/web/app/components/base/chat/chat/loading-anim/style.module.css b/web/app/components/base/chat/chat/loading-anim/style.module.css index b1371ec82a..d5a373df6f 100644 --- a/web/app/components/base/chat/chat/loading-anim/style.module.css +++ b/web/app/components/base/chat/chat/loading-anim/style.module.css @@ -1,6 +1,6 @@ .dot-flashing { position: relative; - animation: 1s infinite linear alternate; + animation: dot-flashing 1s infinite linear alternate; animation-delay: 0.5s; } @@ -10,7 +10,7 @@ display: inline-block; position: absolute; top: 0; - animation: 1s infinite linear alternate; + animation: dot-flashing 1s infinite linear alternate; } .dot-flashing::before { @@ -51,15 +51,21 @@ border-radius: 50%; background-color: #667085; color: #667085; - animation-name: dot-flashing; + animation: dot-flashing 1s infinite linear alternate; +} + +.text { + animation-delay: 0.5s; } .text::before { left: -7px; + animation-delay: 0s; } .text::after { left: 7px; + animation-delay: 1s; } .avatar, @@ -70,13 +76,19 @@ border-radius: 50%; background-color: #155EEF; color: #155EEF; - animation-name: dot-flashing-avatar; + animation: dot-flashing-avatar 1s infinite linear alternate; +} + +.avatar { + animation-delay: 0.5s; } .avatar::before { left: -5px; + animation-delay: 0s; } .avatar::after { left: 5px; + animation-delay: 1s; } diff --git a/web/app/components/base/date-and-time-picker/date-picker/index.tsx b/web/app/components/base/date-and-time-picker/date-picker/index.tsx index 0957b673cd..3114e80d90 100644 --- a/web/app/components/base/date-and-time-picker/date-picker/index.tsx +++ b/web/app/components/base/date-and-time-picker/date-picker/index.tsx @@ -56,8 +56,8 @@ const DatePicker = ({ const [currentDate, setCurrentDate] = useState(inputValue || defaultValue) const [selectedDate, setSelectedDate] = useState(inputValue) - const [selectedMonth, setSelectedMonth] = useState((inputValue || defaultValue).month()) - const [selectedYear, setSelectedYear] = useState((inputValue || defaultValue).year()) + const [selectedMonth, setSelectedMonth] = useState(() => (inputValue || defaultValue).month()) + const [selectedYear, setSelectedYear] = useState(() => (inputValue || defaultValue).year()) useEffect(() => { const handleClickOutside = (event: MouseEvent) => { diff --git a/web/app/components/base/date-and-time-picker/time-picker/index.tsx b/web/app/components/base/date-and-time-picker/time-picker/index.tsx index 830ba4bf0b..eb21d739af 100644 --- 
a/web/app/components/base/date-and-time-picker/time-picker/index.tsx +++ b/web/app/components/base/date-and-time-picker/time-picker/index.tsx @@ -29,7 +29,7 @@ const TimePicker = ({ const [isOpen, setIsOpen] = useState(false) const containerRef = useRef(null) const isInitial = useRef(true) - const [selectedTime, setSelectedTime] = useState(value ? getDateWithTimezone({ timezone, date: value }) : undefined) + const [selectedTime, setSelectedTime] = useState(() => value ? getDateWithTimezone({ timezone, date: value }) : undefined) useEffect(() => { const handleClickOutside = (event: MouseEvent) => { diff --git a/web/app/components/base/markdown-blocks/think-block.tsx b/web/app/components/base/markdown-blocks/think-block.tsx index acceecd433..a3b0561677 100644 --- a/web/app/components/base/markdown-blocks/think-block.tsx +++ b/web/app/components/base/markdown-blocks/think-block.tsx @@ -37,7 +37,7 @@ const removeEndThink = (children: any): any => { const useThinkTimer = (children: any) => { const { isResponding } = useChatContext() - const [startTime] = useState(Date.now()) + const [startTime] = useState(() => Date.now()) const [elapsedTime, setElapsedTime] = useState(0) const [isComplete, setIsComplete] = useState(false) const timerRef = useRef() diff --git a/web/app/components/base/notion-page-selector/base.tsx b/web/app/components/base/notion-page-selector/base.tsx index 1c54b57a18..adf044c406 100644 --- a/web/app/components/base/notion-page-selector/base.tsx +++ b/web/app/components/base/notion-page-selector/base.tsx @@ -93,7 +93,7 @@ const NotionPageSelector = ({ const defaultSelectedPagesId = useMemo(() => { return [...Array.from(pagesMapAndSelectedPagesId[1]), ...(value || [])] }, [pagesMapAndSelectedPagesId, value]) - const [selectedPagesId, setSelectedPagesId] = useState>(new Set(defaultSelectedPagesId)) + const [selectedPagesId, setSelectedPagesId] = useState>(() => new Set(defaultSelectedPagesId)) useEffect(() => { setSelectedPagesId(new Set(defaultSelectedPagesId)) diff --git a/web/app/components/base/tab-slider/index.tsx b/web/app/components/base/tab-slider/index.tsx index 56cde52154..55c44d5ea8 100644 --- a/web/app/components/base/tab-slider/index.tsx +++ b/web/app/components/base/tab-slider/index.tsx @@ -21,7 +21,7 @@ const TabSlider: FC = ({ onChange, options, }) => { - const [activeIndex, setActiveIndex] = useState(options.findIndex(option => option.value === value)) + const [activeIndex, setActiveIndex] = useState(() => options.findIndex(option => option.value === value)) const [sliderStyle, setSliderStyle] = useState({}) const { data: pluginList } = useInstalledPluginList() diff --git a/web/app/components/custom/custom-web-app-brand/index.tsx b/web/app/components/custom/custom-web-app-brand/index.tsx index ea2f44caea..eb06265042 100644 --- a/web/app/components/custom/custom-web-app-brand/index.tsx +++ b/web/app/components/custom/custom-web-app-brand/index.tsx @@ -38,7 +38,7 @@ const CustomWebAppBrand = () => { isCurrentWorkspaceManager, } = useAppContext() const [fileId, setFileId] = useState('') - const [imgKey, setImgKey] = useState(Date.now()) + const [imgKey, setImgKey] = useState(() => Date.now()) const [uploadProgress, setUploadProgress] = useState(0) const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const isSandbox = enableBilling && plan.type === Plan.sandbox diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index 57d357442f..ed230c52ce 100644 --- 
a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -40,7 +40,7 @@ const RetrievalMethodConfig: FC = ({ onChange({ ...value, search_method: retrieveMethod, - ...(!value.reranking_model.reranking_model_name + ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name) ? { reranking_model: { reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', @@ -57,7 +57,7 @@ const RetrievalMethodConfig: FC = ({ onChange({ ...value, search_method: retrieveMethod, - ...(!value.reranking_model.reranking_model_name + ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name) ? { reranking_model: { reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 216a56ab16..0c28149d56 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -54,7 +54,7 @@ const RetrievalParamConfig: FC = ({ }, ) - const handleDisabledSwitchClick = useCallback((enable: boolean) => { + const handleToggleRerankEnable = useCallback((enable: boolean) => { if (enable && !currentModel) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) onChange({ @@ -119,7 +119,7 @@ const RetrievalParamConfig: FC = ({ )}
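The `chat/index.tsx` hunk above fixes a leaked listener: the old code passed a freshly created `debounce(handleWindowResize)` straight to `addEventListener`, so the cleanup's `removeEventListener(handleWindowResize)` never matched the function that was actually registered, and the wrapper (plus any pending trailing call) outlived the component. Several hunks in this range also move `useState(expr)` to the lazy form `useState(() => expr)`, which computes the initial value once on mount instead of on every render. A minimal sketch of both patterns, assuming lodash's `debounce`; the component and its names are illustrative, not part of this diff:

```tsx
import { debounce } from 'lodash-es'
import { useEffect, useState } from 'react'

const ResizeAware = () => {
  // Lazy initializers: the callback runs only for the first render,
  // so Date.now() / window.innerWidth are not re-evaluated (and thrown
  // away) every time the component re-renders.
  const [mountedAt] = useState(() => Date.now())
  const [width, setWidth] = useState(() => window.innerWidth)

  useEffect(() => {
    // Keep a reference to the debounced wrapper. An inline
    // `debounce(handler)` would be a brand-new function, so the
    // matching removeEventListener below could never unregister it.
    const onResize = debounce(() => setWidth(window.innerWidth), 200)
    window.addEventListener('resize', onResize)

    return () => {
      window.removeEventListener('resize', onResize)
      onResize.cancel() // drop a trailing call still queued at unmount
    }
  }, [])

  return <div>{`mounted ${mountedAt}, width ${width}px`}</div>
}

export default ResizeAware
```

The `cancel()` call matters because a resize fired just before unmount would otherwise still invoke the debounced handler, and thus a state update, after the component is gone. The `loading-anim` CSS hunk is the same class of bug in another syntax: the `animation:` shorthand never named its keyframes, so `dot-flashing` only ran where an `animation-name` happened to be set elsewhere.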
diff --git a/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx b/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx index 4c0de924d1..f4f7749f7f 100644 --- a/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx @@ -52,7 +52,7 @@ const InstallFromMarketplace = ({
setCollapse(!collapse)}> - {t('common.modelProvider.installProvider')} + {t('common.modelProvider.installDataSourceProvider')}
{t('common.modelProvider.discoverMore')} diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index b10aeeb47e..48dc609795 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -323,15 +323,18 @@ export const useRefreshModel = () => { const { eventEmitter } = useEventEmitterContextContext() const updateModelProviders = useUpdateModelProviders() const updateModelList = useUpdateModelList() - const handleRefreshModel = useCallback((provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => { + const handleRefreshModel = useCallback(( + provider: ModelProvider, + CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + refreshModelList?: boolean, + ) => { updateModelProviders() provider.supported_model_types.forEach((type) => { updateModelList(type) }) - if (configurationMethod === ConfigurationMethodEnum.customizableModel - && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { + if (refreshModelList && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { eventEmitter?.emit({ type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, payload: provider.provider, diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts index 14b21be7f7..3136a70563 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts @@ -90,7 +90,7 @@ export const useAuth = ( type: 'success', message: t('common.api.actionSuccess'), }) - handleRefreshModel(provider, configurationMethod, undefined) + handleRefreshModel(provider, undefined, true) } finally { handleSetDoingAction(false) @@ -125,7 +125,7 @@ export const useAuth = ( type: 'success', message: t('common.api.actionSuccess'), }) - handleRefreshModel(provider, configurationMethod, undefined) + handleRefreshModel(provider, undefined, true) onRemove?.(pendingOperationCredentialId.current ?? 
'') closeConfirmDelete() } @@ -147,7 +147,7 @@ export const useAuth = ( if (res.result === 'success') { notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - handleRefreshModel(provider, configurationMethod, undefined) + handleRefreshModel(provider, undefined, !payload.credential_id) } } finally { diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx index 070c2ee90f..090147897b 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx @@ -159,7 +159,7 @@ const ModelLoadBalancingModal = ({ ) if (res.result === 'success') { notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - handleRefreshModel(provider, configurateMethod, currentCustomConfigurationModelFixedFields) + handleRefreshModel(provider, currentCustomConfigurationModelFixedFields, false) onSave?.(provider.provider) onClose?.() } diff --git a/web/app/components/header/maintenance-notice.tsx b/web/app/components/header/maintenance-notice.tsx index 4bb4ef7f7d..bcbb344b2c 100644 --- a/web/app/components/header/maintenance-notice.tsx +++ b/web/app/components/header/maintenance-notice.tsx @@ -6,7 +6,7 @@ import { useLanguage } from '@/app/components/header/account-setting/model-provi const MaintenanceNotice = () => { const locale = useLanguage() - const [showNotice, setShowNotice] = useState(localStorage.getItem('hide-maintenance-notice') !== '1') + const [showNotice, setShowNotice] = useState(() => localStorage.getItem('hide-maintenance-notice') !== '1') const handleJumpNotice = () => { window.open(NOTICE_I18N.href, '_blank') } diff --git a/web/app/components/signin/countdown.tsx b/web/app/components/signin/countdown.tsx index 5fd6a29712..c16bd46fe4 100644 --- a/web/app/components/signin/countdown.tsx +++ b/web/app/components/signin/countdown.tsx @@ -12,7 +12,7 @@ type CountdownProps = { export default function Countdown({ onResend }: CountdownProps) { const { t } = useTranslation() - const [leftTime, setLeftTime] = useState(Number(localStorage.getItem(COUNT_DOWN_KEY) || COUNT_DOWN_TIME_MS)) + const [leftTime, setLeftTime] = useState(() => Number(localStorage.getItem(COUNT_DOWN_KEY) || COUNT_DOWN_TIME_MS)) const [time] = useCountDown({ leftTime, onEnd: () => { diff --git a/web/app/components/tools/mcp/modal.tsx b/web/app/components/tools/mcp/modal.tsx index 211d594caf..1a12b3b3e9 100644 --- a/web/app/components/tools/mcp/modal.tsx +++ b/web/app/components/tools/mcp/modal.tsx @@ -65,7 +65,7 @@ const MCPModal = ({ const originalServerID = data?.server_identifier const [url, setUrl] = React.useState(data?.server_url || '') const [name, setName] = React.useState(data?.name || '') - const [appIcon, setAppIcon] = useState(getIcon(data)) + const [appIcon, setAppIcon] = useState(() => getIcon(data)) const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [serverIdentifier, setServerIdentifier] = React.useState(data?.server_identifier || '') const [timeout, setMcpTimeout] = React.useState(data?.timeout || 30) diff --git a/web/app/components/tools/provider-list.tsx b/web/app/components/tools/provider-list.tsx index d267b49c79..08a4aa0b5d 100644 --- a/web/app/components/tools/provider-list.tsx +++ 
b/web/app/components/tools/provider-list.tsx @@ -17,7 +17,7 @@ import CardMoreInfo from '@/app/components/plugins/card/card-more-info' import PluginDetailPanel from '@/app/components/plugins/plugin-detail-panel' import MCPList from './mcp' import { useAllToolProviders } from '@/service/use-tools' -import { useInstalledPluginList, useInvalidateInstalledPluginList } from '@/service/use-plugins' +import { useCheckInstalled, useInvalidateInstalledPluginList } from '@/service/use-plugins' import { useGlobalPublicStore } from '@/context/global-public-context' import { ToolTypeEnum } from '../workflow/block-selector/types' import { useMarketplace } from './marketplace/hooks' @@ -77,12 +77,14 @@ const ProviderList = () => { const currentProvider = useMemo(() => { return filteredCollectionList.find(collection => collection.id === currentProviderId) }, [currentProviderId, filteredCollectionList]) - const { data: pluginList } = useInstalledPluginList() + const { data: checkedInstalledData } = useCheckInstalled({ + pluginIds: currentProvider?.plugin_id ? [currentProvider.plugin_id] : [], + enabled: !!currentProvider?.plugin_id, + }) const invalidateInstalledPluginList = useInvalidateInstalledPluginList() const currentPluginDetail = useMemo(() => { - const detail = pluginList?.plugins.find(plugin => plugin.plugin_id === currentProvider?.plugin_id) - return detail - }, [currentProvider?.plugin_id, pluginList?.plugins]) + return checkedInstalledData?.plugins?.[0] + }, [checkedInstalledData]) const toolListTailRef = useRef(null) const showMarketplacePanel = useCallback(() => { diff --git a/web/app/components/workflow/nodes/_base/hooks/use-resize-panel.ts b/web/app/components/workflow/nodes/_base/hooks/use-resize-panel.ts index f2259a02cf..336c440d58 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-resize-panel.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-resize-panel.ts @@ -33,7 +33,7 @@ export const useResizePanel = (params?: UseResizePanelParams) => { const initContainerWidthRef = useRef(0) const initContainerHeightRef = useRef(0) const isResizingRef = useRef(false) - const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect) + const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(() => getComputedStyle(document.body).userSelect) const handleStartResize = useCallback((e: MouseEvent) => { initXRef.current = e.clientX diff --git a/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts b/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts index a61cad646f..44774074dc 100644 --- a/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts +++ b/web/app/components/workflow/nodes/http/hooks/use-key-value-list.ts @@ -16,7 +16,7 @@ const strToKeyValueList = (value: string) => { } const useKeyValueList = (value: string, onChange: (value: string) => void, noFilter?: boolean) => { - const [list, doSetList] = useState(value ? strToKeyValueList(value) : []) + const [list, doSetList] = useState(() => value ? 
strToKeyValueList(value) : []) const setList = (l: KeyValue[]) => { doSetList(l.map((item) => { return { diff --git a/web/app/components/workflow/nodes/knowledge-base/components/option-card.tsx b/web/app/components/workflow/nodes/knowledge-base/components/option-card.tsx index c15157fc5c..789e24835f 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/option-card.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/option-card.tsx @@ -86,7 +86,10 @@ const OptionCard = memo(({ readonly && 'cursor-not-allowed', wrapperClassName && (typeof wrapperClassName === 'function' ? wrapperClassName(isActive) : wrapperClassName), )} - onClick={() => !readonly && enableSelect && id && onClick?.(id)} + onClick={(e) => { + e.stopPropagation() + !readonly && enableSelect && id && onClick?.(id) + }} >
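In the `provider-list.tsx` hunk above, the detail panel stops fetching the entire installed-plugin list just to locate one entry: `useCheckInstalled` is asked about exactly the current provider's `plugin_id`, and `enabled` gates the request so nothing fires while no provider is selected. A sketch of that conditional-query shape, using the `{ pluginIds, enabled }` options as the hunk does; the surrounding component and its props are hypothetical:

```tsx
import { useMemo } from 'react'
import { useCheckInstalled } from '@/service/use-plugins'

const PluginDetail = ({ pluginId }: { pluginId?: string }) => {
  // Only issue the request when there is a plugin to check;
  // enabled: false keeps the query idle rather than fetching and
  // filtering the whole installed-plugin list.
  const { data } = useCheckInstalled({
    pluginIds: pluginId ? [pluginId] : [],
    enabled: !!pluginId,
  })

  // The hunk reads plugins[0]: with a single id requested, the first
  // (and only possible) entry is the matching plugin, or undefined
  // when it is not installed.
  const detail = useMemo(() => data?.plugins?.[0], [data])

  return <div>{detail ? detail.plugin_id : 'not installed'}</div>
}

export default PluginDetail
```

The `e.stopPropagation()` added in `option-card.tsx` just above is a smaller guard in the same spirit: it keeps a click that selects an option from also bubbling to ancestor click handlers, presumably because the card is rendered inside another clickable region.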
= { chunk_structure, indexing_technique, retrieval_model, + embedding_model, + embedding_model_provider, + index_chunk_variable_selector, } = payload + const { + search_method, + reranking_enable, + reranking_model, + } = retrieval_model || {} + if (!chunk_structure) { return { isValid: false, @@ -36,6 +46,13 @@ const nodeDefault: NodeDefault = { } } + if (index_chunk_variable_selector.length === 0) { + return { + isValid: false, + errorMessage: t('workflow.nodes.knowledgeBase.chunksVariableIsRequired'), + } + } + if (!indexing_technique) { return { isValid: false, @@ -43,13 +60,27 @@ const nodeDefault: NodeDefault = { } } - if (!retrieval_model || !retrieval_model.search_method) { + if (indexing_technique === IndexingType.QUALIFIED && (!embedding_model || !embedding_model_provider)) { + return { + isValid: false, + errorMessage: t('workflow.nodes.knowledgeBase.embeddingModelIsRequired'), + } + } + + if (!retrieval_model || !search_method) { return { isValid: false, errorMessage: t('workflow.nodes.knowledgeBase.retrievalSettingIsRequired'), } } + if (reranking_enable && (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name)) { + return { + isValid: false, + errorMessage: t('workflow.nodes.knowledgeBase.rerankingModelIsRequired'), + } + } + return { isValid: true, errorMessage: '', diff --git a/web/app/components/workflow/nodes/knowledge-base/hooks/use-config.ts b/web/app/components/workflow/nodes/knowledge-base/hooks/use-config.ts index 365722feba..8b22704c5a 100644 --- a/web/app/components/workflow/nodes/knowledge-base/hooks/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-base/hooks/use-config.ts @@ -9,13 +9,17 @@ import { ChunkStructureEnum, IndexMethodEnum, RetrievalSearchMethodEnum, + WeightedScoreEnum, } from '../types' import type { - HybridSearchModeEnum, KnowledgeBaseNodeType, RerankingModel, } from '../types' +import { + HybridSearchModeEnum, +} from '../types' import { isHighQualitySearchMethod } from '../utils' +import { DEFAULT_WEIGHTED_SCORE, RerankingModeEnum } from '@/models/datasets' export const useConfig = (id: string) => { const store = useStoreApi() @@ -35,6 +39,25 @@ export const useConfig = (id: string) => { }) }, [id, handleNodeDataUpdateWithSyncDraft]) + const getDefaultWeights = useCallback(({ + embeddingModel, + embeddingModelProvider, + }: { + embeddingModel: string + embeddingModelProvider: string + }) => { + return { + vector_setting: { + vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, + embedding_provider_name: embeddingModelProvider || '', + embedding_model_name: embeddingModel, + }, + keyword_setting: { + keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, + }, + } + }, []) + const handleChunkStructureChange = useCallback((chunkStructure: ChunkStructureEnum) => { const nodeData = getNodeData() const { @@ -80,39 +103,72 @@ export const useConfig = (id: string) => { embeddingModelProvider: string }) => { const nodeData = getNodeData() - handleNodeDataUpdate({ + const defaultWeights = getDefaultWeights({ + embeddingModel, + embeddingModelProvider, + }) + const changeData = { embedding_model: embeddingModel, embedding_model_provider: embeddingModelProvider, retrieval_model: { ...nodeData?.data.retrieval_model, - vector_setting: { - ...nodeData?.data.retrieval_model.vector_setting, - embedding_provider_name: embeddingModelProvider, - embedding_model_name: embeddingModel, - }, }, - }) - }, [getNodeData, handleNodeDataUpdate]) + } + if (changeData.retrieval_model.weights) { + 
changeData.retrieval_model = { + ...changeData.retrieval_model, + weights: { + ...changeData.retrieval_model.weights, + vector_setting: { + ...changeData.retrieval_model.weights.vector_setting, + embedding_provider_name: embeddingModelProvider, + embedding_model_name: embeddingModel, + }, + }, + } + } + else { + changeData.retrieval_model = { + ...changeData.retrieval_model, + weights: defaultWeights, + } + } + handleNodeDataUpdate(changeData) + }, [getNodeData, getDefaultWeights, handleNodeDataUpdate]) const handleRetrievalSearchMethodChange = useCallback((searchMethod: RetrievalSearchMethodEnum) => { const nodeData = getNodeData() - handleNodeDataUpdate({ + const changeData = { retrieval_model: { ...nodeData?.data.retrieval_model, search_method: searchMethod, + reranking_mode: nodeData?.data.retrieval_model.reranking_mode || RerankingModeEnum.RerankingModel, }, - }) + } + if (searchMethod === RetrievalSearchMethodEnum.hybrid) { + changeData.retrieval_model = { + ...changeData.retrieval_model, + reranking_enable: changeData.retrieval_model.reranking_mode === RerankingModeEnum.RerankingModel, + } + } + handleNodeDataUpdate(changeData) }, [getNodeData, handleNodeDataUpdate]) const handleHybridSearchModeChange = useCallback((hybridSearchMode: HybridSearchModeEnum) => { const nodeData = getNodeData() + const defaultWeights = getDefaultWeights({ + embeddingModel: nodeData?.data.embedding_model || '', + embeddingModelProvider: nodeData?.data.embedding_model_provider || '', + }) handleNodeDataUpdate({ retrieval_model: { ...nodeData?.data.retrieval_model, reranking_mode: hybridSearchMode, + reranking_enable: hybridSearchMode === HybridSearchModeEnum.RerankingModel, + weights: nodeData?.data.retrieval_model.weights || defaultWeights, }, }) - }, [getNodeData, handleNodeDataUpdate]) + }, [getNodeData, getDefaultWeights, handleNodeDataUpdate]) const handleRerankingModelEnabledChange = useCallback((rerankingModelEnabled: boolean) => { const nodeData = getNodeData() @@ -130,11 +186,10 @@ export const useConfig = (id: string) => { retrieval_model: { ...nodeData?.data.retrieval_model, weights: { - weight_type: 'weighted_score', + weight_type: WeightedScoreEnum.Customized, vector_setting: { + ...nodeData?.data.retrieval_model.weights?.vector_setting, vector_weight: weightedScore.value[0], - embedding_provider_name: '', - embedding_model_name: '', }, keyword_setting: { keyword_weight: weightedScore.value[1], diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx index 8a3dc1efba..619216d672 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useState } from 'react' +import React, { useCallback, useMemo } from 'react' import { RiEqualizer2Line } from '@remixicon/react' import { useTranslation } from 'react-i18next' import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types' @@ -14,8 +14,6 @@ import { import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content' import { RETRIEVE_TYPE } from '@/types/app' import { DATASET_DEFAULT } from '@/config' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from 
'@/app/components/header/account-setting/model-provider-page/hooks' -import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import Button from '@/app/components/base/button' import type { DatasetConfigs } from '@/models/debug' import type { DataSet } from '@/models/datasets' @@ -32,8 +30,8 @@ type Props = { onSingleRetrievalModelChange?: (config: ModelConfig) => void onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void readonly?: boolean - openFromProps?: boolean - onOpenFromPropsChange?: (openFromProps: boolean) => void + rerankModalOpen: boolean + onRerankModelOpenChange: (open: boolean) => void selectedDatasets: DataSet[] } @@ -45,26 +43,52 @@ const RetrievalConfig: FC = ({ onSingleRetrievalModelChange, onSingleRetrievalModelParamsChange, readonly, - openFromProps, - onOpenFromPropsChange, + rerankModalOpen, + onRerankModelOpenChange, selectedDatasets, }) => { const { t } = useTranslation() - const [open, setOpen] = useState(false) - const mergedOpen = openFromProps !== undefined ? openFromProps : open + const { retrieval_mode, multiple_retrieval_config } = payload const handleOpen = useCallback((newOpen: boolean) => { - setOpen(newOpen) - onOpenFromPropsChange?.(newOpen) - }, [onOpenFromPropsChange]) + onRerankModelOpenChange(newOpen) + }, [onRerankModelOpenChange]) - const { - currentProvider: validRerankDefaultProvider, - currentModel: validRerankDefaultModel, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const datasetConfigs = useMemo(() => { + const { + reranking_model, + top_k, + score_threshold, + reranking_mode, + weights, + reranking_enable, + } = multiple_retrieval_config || {} + + return { + retrieval_model: retrieval_mode, + reranking_model: (reranking_model?.provider && reranking_model?.model) + ? { + reranking_provider_name: reranking_model?.provider, + reranking_model_name: reranking_model?.model, + } + : { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: top_k || DATASET_DEFAULT.top_k, + score_threshold_enabled: !(score_threshold === undefined || score_threshold === null), + score_threshold, + datasets: { + datasets: [], + }, + reranking_mode, + weights, + reranking_enable, + } + }, [retrieval_mode, multiple_retrieval_config]) - const { multiple_retrieval_config } = payload const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => { + // Legacy code, for compatibility, have to keep it if (isRetrievalModeChange) { onRetrievalModeChange(configs.retrieval_model) return @@ -72,13 +96,11 @@ const RetrievalConfig: FC = ({ onMultipleRetrievalConfigChange({ top_k: configs.top_k, score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null, - reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay + reranking_model: retrieval_mode === RETRIEVE_TYPE.oneWay ? undefined + // eslint-disable-next-line sonarjs/no-nested-conditional : (!configs.reranking_model?.reranking_provider_name - ? { - provider: validRerankDefaultProvider?.provider || '', - model: validRerankDefaultModel?.model || '', - } + ? 
undefined : { provider: configs.reranking_model?.reranking_provider_name, model: configs.reranking_model?.reranking_model_name, @@ -87,11 +109,11 @@ const RetrievalConfig: FC = ({ weights: configs.weights, reranking_enable: configs.reranking_enable, }) - }, [onMultipleRetrievalConfigChange, payload.retrieval_mode, validRerankDefaultProvider, validRerankDefaultModel, onRetrievalModeChange]) + }, [onMultipleRetrievalConfigChange, retrieval_mode, onRetrievalModeChange]) return ( = ({ onClick={() => { if (readonly) return - handleOpen(!mergedOpen) + handleOpen(!rerankModalOpen) }} >
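The knowledge-base hunks above do two things: the node defaults tighten validation (a chunks variable selector is required, high-quality (`QUALIFIED`) indexing must name an embedding model, and `reranking_enable` demands a fully specified reranking model), while `use-config.ts` makes the `weights` config self-healing by seeding a default weighted-score setup whenever an embedding model is chosen or hybrid search changes mode and no `weights` exist yet. A paraphrase of that default, with `DEFAULT_WEIGHTED_SCORE` imported from `@/models/datasets` exactly as the hunk does; the standalone function and the local `Weights` type are illustrative:

```ts
import { DEFAULT_WEIGHTED_SCORE } from '@/models/datasets'

type Weights = {
  vector_setting: {
    vector_weight: number
    embedding_provider_name: string
    embedding_model_name: string
  }
  keyword_setting: {
    keyword_weight: number
  }
}

// Seeds the semantic/keyword split for weighted-score reranking and
// pins the embedding model the vector side will score with.
const getDefaultWeights = (
  embeddingModel: string,
  embeddingModelProvider: string,
): Weights => ({
  vector_setting: {
    vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
    embedding_provider_name: embeddingModelProvider || '',
    embedding_model_name: embeddingModel,
  },
  keyword_setting: {
    keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
  },
})
```

Threading the embedding provider and model into `vector_setting` is what lets the weighted-score path keep scoring with the same embedding model the dataset is indexed with, instead of the empty strings the old handler wrote there.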