diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index ef69e08da9..0cae2ef552 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -30,6 +30,8 @@ jobs: run: | uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all + uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all + uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all # Convert Optional[T] to T | None (ignoring quoted types) cat > /tmp/optional-rule.yml << 'EOF' id: convert-optional-to-union diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index de732c3134..cd1c86e668 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -18,7 +18,7 @@ jobs: - name: Deploy to server uses: appleboy/ssh-action@v0.1.8 with: - host: ${{ secrets.RAG_SSH_HOST }} + host: ${{ secrets.SSH_HOST }} username: ${{ secrets.SSH_USER }} key: ${{ secrets.SSH_PRIVATE_KEY }} script: | diff --git a/api/.ruff.toml b/api/.ruff.toml index 643bc063a1..5a29e1d8fa 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -81,7 +81,6 @@ ignore = [ "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false - "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/ ] [lint.per-file-ignores] diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index ba015a6eb9..a7d712545e 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,23 +1,24 @@ -from enum import Enum +from enum import 
StrEnum from typing import Literal from pydantic import Field, PositiveInt from pydantic_settings import BaseSettings +class AuthMethod(StrEnum): + """ + Authentication method for OpenSearch + """ + + BASIC = "basic" + AWS_MANAGED_IAM = "aws_managed_iam" + + class OpenSearchConfig(BaseSettings): """ Configuration settings for OpenSearch """ - class AuthMethod(Enum): - """ - Authentication method for OpenSearch - """ - - BASIC = "basic" - AWS_MANAGED_IAM = "aws_managed_iam" - OPENSEARCH_HOST: str | None = Field( description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 23b8e2c5a2..3927685af3 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -304,7 +304,7 @@ class AppCopyApi(Resource): account = cast(Account, current_user) result = import_service.import_app( account=account, - import_mode=ImportMode.YAML_CONTENT.value, + import_mode=ImportMode.YAML_CONTENT, yaml_content=yaml_content, name=args.get("name"), description=args.get("description"), diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index c14f597c25..037561cfed 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -70,9 +70,9 @@ class AppImportApi(Resource): EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") # Return appropriate status code based on result status = result.status - if status == ImportStatus.FAILED.value: + if status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 - elif status == ImportStatus.PENDING.value: + elif status == ImportStatus.PENDING: return result.model_dump(mode="json"), 202 return result.model_dump(mode="json"), 200 @@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource): session.commit() # Return appropriate status code based on result - 
if result.status == ImportStatus.FAILED.value: + if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 200 diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index f104ab5dee..3b8dff613b 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -309,7 +309,7 @@ class ChatConversationApi(Resource): ) if app_model.mode == AppMode.ADVANCED_CHAT: - query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) + query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) match args["sort_by"]: case "created_at": diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 11df511840..e71b774d3e 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -90,7 +90,7 @@ class ModelConfigResource(Resource): if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity.model_validate(tool) # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( @@ -124,7 +124,7 @@ class ModelConfigResource(Resource): # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict for tool in agent_mode.get("tools") or []: - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity.model_validate(tool) # get tool key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 6471b843c6..5974395c6a 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -52,7 +52,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": 
account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource): sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), ) .select_from(Message) - .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) + .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER) ) if args["start"]: @@ -190,7 +190,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -263,7 +263,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -345,7 +345,7 @@ FROM WHERE c.app_id = :app_id AND m.invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -432,7 +432,7 @@ LEFT JOIN WHERE m.app_id = :app_id AND m.invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = 
pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -509,7 +509,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc @@ -584,7 +584,7 @@ FROM WHERE app_id = :app_id AND invoke_from != :invoke_from""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} + arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 535e7cadd6..b8904bf3d9 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -47,7 +47,7 @@ WHERE arg_dict = { "tz": account.timezone, "app_id": app_model.id, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, } timezone = pytz.timezone(account.timezone) @@ -115,7 +115,7 @@ WHERE arg_dict = { "tz": account.timezone, "app_id": app_model.id, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, } timezone = pytz.timezone(account.timezone) @@ -183,7 +183,7 @@ WHERE arg_dict = { "tz": account.timezone, "app_id": app_model.id, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, } timezone = pytz.timezone(account.timezone) @@ -269,7 +269,7 @@ GROUP BY arg_dict = { "tz": account.timezone, "app_id": app_model.id, - "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN, } timezone = 
pytz.timezone(account.timezone) diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8cdadfb03c..76171e3f8a 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -103,7 +103,7 @@ class ActivateApi(Resource): account.interface_language = args["interface_language"] account.timezone = args["timezone"] account.interface_theme = "light" - account.status = AccountStatus.ACTIVE.value + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5528dc0569..4efeceb676 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -130,11 +130,11 @@ class OAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}") # Check account status - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.") - if account.status == AccountStatus.PENDING.value: - account.status = AccountStatus.ACTIVE.value + if account.status == AccountStatus.PENDING: + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 370e0c0d14..6d9d675e87 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -15,7 +15,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType, from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.indexing_runner import IndexingRunner from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting 
import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields @@ -256,14 +256,16 @@ class DataSourceNotionApi(Resource): credential_id = notion_info.get("credential_id") for page in notion_info["pages"]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": credential_id, - "notion_workspace_id": workspace_id, - "notion_obj_id": page["page_id"], - "notion_page_type": page["type"], - "tenant_id": current_user.current_tenant_id, - }, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": credential_id, + "notion_workspace_id": workspace_id, + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ac088b790e..dda0125687 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -24,7 +24,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.app_fields import related_app_list @@ -500,7 +500,7 @@ class DatasetIndexingEstimateApi(Resource): if 
file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=args["doc_form"], ) @@ -512,14 +512,16 @@ class DatasetIndexingEstimateApi(Resource): credential_id = notion_info.get("credential_id") for page in notion_info["pages"]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": credential_id, - "notion_workspace_id": workspace_id, - "notion_obj_id": page["page_id"], - "notion_page_type": page["type"], - "tenant_id": current_user.current_tenant_id, - }, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": credential_id, + "notion_workspace_id": workspace_id, + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) @@ -527,15 +529,17 @@ class DatasetIndexingEstimateApi(Resource): website_info_list = args["info_list"]["website_info_list"] for url in website_info_list["urls"]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": website_info_list["provider"], - "job_id": website_info_list["job_id"], - "url": url, - "tenant_id": current_user.current_tenant_id, - "mode": "crawl", - "only_main_content": website_info_list["only_main_content"], - }, + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], + "url": url, + "tenant_id": current_user.current_tenant_id, + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) @@ -782,7 +786,7 @@ class DatasetRetrievalSettingApi(Resource): | 
VectorType.VIKINGDB | VectorType.UPSTASH ): - return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]} case ( VectorType.QDRANT | VectorType.WEAVIATE @@ -809,9 +813,9 @@ class DatasetRetrievalSettingApi(Resource): ): return { "retrieval_method": [ - RetrievalMethod.SEMANTIC_SEARCH.value, - RetrievalMethod.FULL_TEXT_SEARCH.value, - RetrievalMethod.HYBRID_SEARCH.value, + RetrievalMethod.SEMANTIC_SEARCH, + RetrievalMethod.FULL_TEXT_SEARCH, + RetrievalMethod.HYBRID_SEARCH, ] } case _: @@ -838,7 +842,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.VIKINGDB | VectorType.UPSTASH ): - return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]} case ( VectorType.QDRANT | VectorType.WEAVIATE @@ -863,9 +867,9 @@ class DatasetRetrievalSettingMockApi(Resource): ): return { "retrieval_method": [ - RetrievalMethod.SEMANTIC_SEARCH.value, - RetrievalMethod.FULL_TEXT_SEARCH.value, - RetrievalMethod.HYBRID_SEARCH.value, + RetrievalMethod.SEMANTIC_SEARCH, + RetrievalMethod.FULL_TEXT_SEARCH, + RetrievalMethod.HYBRID_SEARCH, ] } case _: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index c5fa2061bf..011dacde76 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from extensions.ext_database import db from fields.document_fields 
import ( dataset_and_document_fields, @@ -305,7 +305,7 @@ class DatasetDocumentListApi(Resource): "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) args = parser.parse_args() - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) if not dataset.indexing_technique and not knowledge_config.indexing_technique: raise ValueError("indexing_technique is required.") @@ -395,7 +395,7 @@ class DatasetInitApi(Resource): parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) if knowledge_config.indexing_technique == "high_quality": if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") @@ -475,7 +475,7 @@ class DocumentIndexingEstimateApi(DocumentResource): raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form + datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() @@ -538,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) @@ -546,14 +546,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": 
data_source_info["credential_id"], - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "tenant_id": current_user.current_tenant_id, - }, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info["credential_id"], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=document.doc_form, ) extract_settings.append(extract_setting) @@ -561,15 +563,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "url": data_source_info["url"], - "tenant_id": current_user.current_tenant_id, - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - }, + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], + "tenant_id": current_user.current_tenant_id, + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), document_model=document.doc_form, ) extract_settings.append(extract_setting) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 9f2805e2c6..d6bd02483d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -309,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): ) args = parser.parse_args() 
SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @@ -564,7 +564,7 @@ class ChildChunkAddApi(Resource): args = parser.parse_args() try: chunks_data = args["chunks"] - chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data] + chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data] child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index dc3cd3fce9..8438458617 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -28,7 +28,7 @@ class DatasetMetadataCreateApi(Resource): parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - metadata_args = MetadataArgs(**args) + metadata_args = MetadataArgs.model_validate(args) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -137,7 +137,7 @@ class DocumentMetadataEditApi(Resource): parser = reqparse.RequestParser() parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") args = parser.parse_args() - metadata_args = MetadataOperationData(**args) + metadata_args = MetadataOperationData.model_validate(args) MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py 
b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 3af590afc8..e021f95283 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -88,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource): nullable=True, ) args = parser.parse_args() - pipeline_template_info = PipelineTemplateInfoEntity(**args) + pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) return 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index e0b918456b..a82872ba2b 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -60,9 +60,9 @@ class RagPipelineImportApi(Resource): # Return appropriate status code based on result status = result.status - if status == ImportStatus.FAILED.value: + if status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 - elif status == ImportStatus.PENDING.value: + elif status == ImportStatus.PENDING: return result.model_dump(mode="json"), 202 return result.model_dump(mode="json"), 200 @@ -87,7 +87,7 @@ class RagPipelineImportConfirmApi(Resource): session.commit() # Return appropriate status code based on result - if result.status == ImportStatus.FAILED.value: + if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 200 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index bdc3fb0dbd..c86c243c9b 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse from 
sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound -from controllers.console import api +from controllers.console import console_ns from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db @@ -22,6 +22,7 @@ from services.feature_service import FeatureService logger = logging.getLogger(__name__) +@console_ns.route("/installed-apps") class InstalledAppsListApi(Resource): @login_required @account_initialization_required @@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource): return {"message": "App installed successfully"} +@console_ns.route("/installed-apps/") class InstalledAppApi(InstalledAppResource): """ update and delete an installed app @@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource): db.session.commit() return {"result": "success", "message": "App info updated successfully"} - - -api.add_resource(InstalledAppsListApi, "/installed-apps") -api.add_resource(InstalledAppApi, "/installed-apps/") diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 7742ea24a9..9c6b2aedfb 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,7 +1,7 @@ from flask_restx import marshal_with from controllers.common import fields -from controllers.console import api +from controllers.console import console_ns from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict @@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp from services.app_service import AppService +@console_ns.route("/installed-apps//parameters", endpoint="installed_app_parameters") class 
AppParameterApi(InstalledAppResource): """Resource for app variables.""" @@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource): return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) +@console_ns.route("/installed-apps//meta", endpoint="installed_app_meta") class ExploreAppMetaApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Get app meta""" @@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource): if not app_model: raise ValueError("App not found") return AppService().get_app_meta(app_model) - - -api.add_resource( - AppParameterApi, "/installed-apps//parameters", endpoint="installed_app_parameters" -) -api.add_resource(ExploreAppMetaApi, "/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 974222ddf7..6d627a929a 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,7 +1,7 @@ from flask_restx import Resource, fields, marshal_with, reqparse from constants.languages import languages -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required from libs.helper import AppIconUrlField from libs.login import current_user, login_required @@ -35,6 +35,7 @@ recommended_app_list_fields = { } +@console_ns.route("/explore/apps") class RecommendedAppListApi(Resource): @login_required @account_initialization_required @@ -56,13 +57,10 @@ class RecommendedAppListApi(Resource): return RecommendedAppService.get_recommended_apps_and_categories(language_prefix) +@console_ns.route("/explore/apps/") class RecommendedAppApi(Resource): @login_required @account_initialization_required def get(self, app_id): app_id = str(app_id) return RecommendedAppService.get_recommend_app_detail(app_id) - - 
-api.add_resource(RecommendedAppListApi, "/explore/apps") -api.add_resource(RecommendedAppApi, "/explore/apps/") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 6f05f898f9..79e4a4339e 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound -from controllers.console import api +from controllers.console import console_ns from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields @@ -25,6 +25,7 @@ message_fields = { } +@console_ns.route("/installed-apps//saved-messages", endpoint="installed_app_saved_messages") class SavedMessageListApi(InstalledAppResource): saved_message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource): return {"result": "success"} +@console_ns.route( + "/installed-apps//saved-messages/", endpoint="installed_app_saved_message" +) class SavedMessageApi(InstalledAppResource): def delete(self, installed_app, message_id): app_model = installed_app.app @@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource): SavedMessageService.delete(app_model, current_user, message_id) return {"result": "success"}, 204 - - -api.add_resource( - SavedMessageListApi, - "/installed-apps//saved-messages", - endpoint="installed_app_saved_messages", -) -api.add_resource( - SavedMessageApi, - "/installed-apps//saved-messages/", - endpoint="installed_app_saved_message", -) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 7a41a8a5cc..e2b0e3f84d 100644 --- a/api/controllers/console/workspace/account.py +++ 
b/api/controllers/console/workspace/account.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language -from controllers.console import api +from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, EmailChangeLimitError, @@ -45,6 +45,7 @@ from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError +@console_ns.route("/account/init") class AccountInitApi(Resource): @setup_required @login_required @@ -97,6 +98,7 @@ class AccountInitApi(Resource): return {"result": "success"} +@console_ns.route("/account/profile") class AccountProfileApi(Resource): @setup_required @login_required @@ -109,6 +111,7 @@ class AccountProfileApi(Resource): return current_user +@console_ns.route("/account/name") class AccountNameApi(Resource): @setup_required @login_required @@ -130,6 +133,7 @@ class AccountNameApi(Resource): return updated_account +@console_ns.route("/account/avatar") class AccountAvatarApi(Resource): @setup_required @login_required @@ -147,6 +151,7 @@ class AccountAvatarApi(Resource): return updated_account +@console_ns.route("/account/interface-language") class AccountInterfaceLanguageApi(Resource): @setup_required @login_required @@ -164,6 +169,7 @@ class AccountInterfaceLanguageApi(Resource): return updated_account +@console_ns.route("/account/interface-theme") class AccountInterfaceThemeApi(Resource): @setup_required @login_required @@ -181,6 +187,7 @@ class AccountInterfaceThemeApi(Resource): return updated_account +@console_ns.route("/account/timezone") class AccountTimezoneApi(Resource): @setup_required @login_required @@ -202,6 +209,7 @@ class AccountTimezoneApi(Resource): return updated_account +@console_ns.route("/account/password") class AccountPasswordApi(Resource): @setup_required @login_required @@ -227,6 +235,7 @@ class 
AccountPasswordApi(Resource): return {"result": "success"} +@console_ns.route("/account/integrates") class AccountIntegrateApi(Resource): integrate_fields = { "provider": fields.String, @@ -283,6 +292,7 @@ class AccountIntegrateApi(Resource): return {"data": integrate_data} +@console_ns.route("/account/delete/verify") class AccountDeleteVerifyApi(Resource): @setup_required @login_required @@ -298,6 +308,7 @@ class AccountDeleteVerifyApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/account/delete") class AccountDeleteApi(Resource): @setup_required @login_required @@ -320,6 +331,7 @@ class AccountDeleteApi(Resource): return {"result": "success"} +@console_ns.route("/account/delete/feedback") class AccountDeleteUpdateFeedbackApi(Resource): @setup_required def post(self): @@ -333,6 +345,7 @@ class AccountDeleteUpdateFeedbackApi(Resource): return {"result": "success"} +@console_ns.route("/account/education/verify") class EducationVerifyApi(Resource): verify_fields = { "token": fields.String, @@ -352,6 +365,7 @@ class EducationVerifyApi(Resource): return BillingService.EducationIdentity.verify(account.id, account.email) +@console_ns.route("/account/education") class EducationApi(Resource): status_fields = { "result": fields.Boolean, @@ -396,6 +410,7 @@ class EducationApi(Resource): return res +@console_ns.route("/account/education/autocomplete") class EducationAutoCompleteApi(Resource): data_fields = { "data": fields.List(fields.String), @@ -419,6 +434,7 @@ class EducationAutoCompleteApi(Resource): return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) +@console_ns.route("/account/change-email") class ChangeEmailSendEmailApi(Resource): @enable_change_email @setup_required @@ -467,6 +483,7 @@ class ChangeEmailSendEmailApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/account/change-email/validity") class ChangeEmailCheckApi(Resource): @enable_change_email 
@setup_required @@ -508,6 +525,7 @@ class ChangeEmailCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/account/change-email/reset") class ChangeEmailResetApi(Resource): @enable_change_email @setup_required @@ -547,6 +565,7 @@ class ChangeEmailResetApi(Resource): return updated_account +@console_ns.route("/account/change-email/check-email-unique") class CheckEmailUnique(Resource): @setup_required def post(self): @@ -558,28 +577,3 @@ class CheckEmailUnique(Resource): if not AccountService.check_email_unique(args["email"]): raise EmailAlreadyInUseError() return {"result": "success"} - - -# Register API resources -api.add_resource(AccountInitApi, "/account/init") -api.add_resource(AccountProfileApi, "/account/profile") -api.add_resource(AccountNameApi, "/account/name") -api.add_resource(AccountAvatarApi, "/account/avatar") -api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") -api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") -api.add_resource(AccountTimezoneApi, "/account/timezone") -api.add_resource(AccountPasswordApi, "/account/password") -api.add_resource(AccountIntegrateApi, "/account/integrates") -api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify") -api.add_resource(AccountDeleteApi, "/account/delete") -api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback") -api.add_resource(EducationVerifyApi, "/account/education/verify") -api.add_resource(EducationApi, "/account/education") -api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete") -# Change email -api.add_resource(ChangeEmailSendEmailApi, "/account/change-email") -api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity") -api.add_resource(ChangeEmailResetApi, "/account/change-email/reset") -api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique") -# api.add_resource(AccountEmailApi, '/account/email') -# 
api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 7c1bc7c075..99a1c1f032 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,7 +1,7 @@ from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService +@console_ns.route( + "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate" +) class LoadBalancingCredentialsValidateApi(Resource): @setup_required @login_required @@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource): return response +@console_ns.route( + "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate" +) class LoadBalancingConfigCredentialsValidateApi(Resource): @setup_required @login_required @@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): response["error"] = error return response - - -# Load Balancing Config -api.add_resource( - LoadBalancingCredentialsValidateApi, - "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", -) - -api.add_resource( - LoadBalancingConfigCredentialsValidateApi, - "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", -) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 
77f0c9a735..8b89853bd9 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -6,7 +6,7 @@ from flask_restx import Resource, marshal_with, reqparse import services from configs import dify_config -from controllers.console import api +from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, EmailCodeError, @@ -33,6 +33,7 @@ from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService +@console_ns.route("/workspaces/current/members") class MemberListApi(Resource): """List all members of current tenant.""" @@ -49,6 +50,7 @@ class MemberListApi(Resource): return {"result": "success", "accounts": members}, 200 +@console_ns.route("/workspaces/current/members/invite-email") class MemberInviteEmailApi(Resource): """Invite a new member by email.""" @@ -111,6 +113,7 @@ class MemberInviteEmailApi(Resource): }, 201 +@console_ns.route("/workspaces/current/members/") class MemberCancelInviteApi(Resource): """Cancel an invitation by member id.""" @@ -143,6 +146,7 @@ class MemberCancelInviteApi(Resource): }, 200 +@console_ns.route("/workspaces/current/members//update-role") class MemberUpdateRoleApi(Resource): """Update member role.""" @@ -177,6 +181,7 @@ class MemberUpdateRoleApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/dataset-operators") class DatasetOperatorMemberListApi(Resource): """List all members of current tenant.""" @@ -193,6 +198,7 @@ class DatasetOperatorMemberListApi(Resource): return {"result": "success", "accounts": members}, 200 +@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") class SendOwnerTransferEmailApi(Resource): """Send owner transfer email.""" @@ -233,6 +239,7 @@ class SendOwnerTransferEmailApi(Resource): return {"result": "success", "data": token} 
+@console_ns.route("/workspaces/current/members/owner-transfer-check") class OwnerTransferCheckApi(Resource): @setup_required @login_required @@ -278,6 +285,7 @@ class OwnerTransferCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/workspaces/current/members//owner-transfer") class OwnerTransfer(Resource): @setup_required @login_required @@ -339,14 +347,3 @@ class OwnerTransfer(Resource): raise ValueError(str(e)) return {"result": "success"} - - -api.add_resource(MemberListApi, "/workspaces/current/members") -api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") -api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") -api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") -api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") -# owner transfer -api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email") -api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check") -api.add_resource(OwnerTransfer, "/workspaces/current/members//owner-transfer") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 0c9db660aa..7012580362 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -5,7 +5,7 @@ from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -17,6 +17,7 @@ from services.billing_service import 
BillingService from services.model_provider_service import ModelProviderService +@console_ns.route("/workspaces/current/model-providers") class ModelProviderListApi(Resource): @setup_required @login_required @@ -45,6 +46,7 @@ class ModelProviderListApi(Resource): return jsonable_encoder({"data": provider_list}) +@console_ns.route("/workspaces/current/model-providers//credentials") class ModelProviderCredentialApi(Resource): @setup_required @login_required @@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/workspaces/current/model-providers//credentials/switch") class ModelProviderCredentialSwitchApi(Resource): @setup_required @login_required @@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/model-providers//credentials/validate") class ModelProviderValidateApi(Resource): @setup_required @login_required @@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource): return response +@console_ns.route("/workspaces//model-providers///") class ModelProviderIconApi(Resource): """ Get model provider icon @@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource): return send_file(io.BytesIO(icon), mimetype=mimetype) +@console_ns.route("/workspaces/current/model-providers//preferred-provider-type") class PreferredProviderTypeUpdateApi(Resource): @setup_required @login_required @@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/model-providers//checkout-url") class ModelProviderPaymentCheckoutUrlApi(Resource): @setup_required @login_required @@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): prefilled_email=current_user.email, ) return data - - -api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") - -api.add_resource(ModelProviderCredentialApi, 
"/workspaces/current/model-providers//credentials") -api.add_resource( - ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers//credentials/switch" -) -api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") - -api.add_resource( - PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" -) -api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url") -api.add_resource( - ModelProviderIconApi, - "/workspaces//model-providers///", -) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index f174fcc5d3..d38bb16ea7 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -4,7 +4,7 @@ from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService logger = logging.getLogger(__name__) +@console_ns.route("/workspaces/current/default-model") class DefaultModelApi(Resource): @setup_required @login_required @@ -85,6 +86,7 @@ class DefaultModelApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/model-providers//models") class ModelProviderModelApi(Resource): @setup_required @login_required @@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/workspaces/current/model-providers//models/credentials") class ModelProviderModelCredentialApi(Resource): @setup_required @login_required @@ -364,6 +367,7 
@@ class ModelProviderModelCredentialApi(Resource): return {"result": "success"}, 204 +@console_ns.route("/workspaces/current/model-providers//models/credentials/switch") class ModelProviderModelCredentialSwitchApi(Resource): @setup_required @login_required @@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource): return {"result": "success"} +@console_ns.route( + "/workspaces/current/model-providers//models/enable", endpoint="model-provider-model-enable" +) class ModelProviderModelEnableApi(Resource): @setup_required @login_required @@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource): return {"result": "success"} +@console_ns.route( + "/workspaces/current/model-providers//models/disable", endpoint="model-provider-model-disable" +) class ModelProviderModelDisableApi(Resource): @setup_required @login_required @@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/model-providers//models/credentials/validate") class ModelProviderModelValidateApi(Resource): @setup_required @login_required @@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource): return response +@console_ns.route("/workspaces/current/model-providers//models/parameter-rules") class ModelProviderModelParameterRuleApi(Resource): @setup_required @login_required @@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource): return jsonable_encoder({"data": parameter_rules}) +@console_ns.route("/workspaces/current/models/model-types/") class ModelProviderAvailableModelApi(Resource): @setup_required @login_required @@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource): models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) return jsonable_encoder({"data": models}) - - -api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers//models") -api.add_resource( - ModelProviderModelEnableApi, - 
"/workspaces/current/model-providers//models/enable", - endpoint="model-provider-model-enable", -) -api.add_resource( - ModelProviderModelDisableApi, - "/workspaces/current/model-providers//models/disable", - endpoint="model-provider-model-disable", -) -api.add_resource( - ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" -) -api.add_resource( - ModelProviderModelCredentialSwitchApi, - "/workspaces/current/model-providers//models/credentials/switch", -) -api.add_resource( - ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" -) - -api.add_resource( - ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" -) -api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") -api.add_resource(DefaultModelApi, "/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index fd5421fa64..7c70fb8aa0 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.console import api +from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder @@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_service import PluginService +@console_ns.route("/workspaces/current/plugin/debugging-key") class PluginDebuggingKeyApi(Resource): @setup_required @login_required @@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/list") 
class PluginListApi(Resource): @setup_required @login_required @@ -55,6 +57,7 @@ class PluginListApi(Resource): return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) +@console_ns.route("/workspaces/current/plugin/list/latest-versions") class PluginListLatestVersionsApi(Resource): @setup_required @login_required @@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource): return jsonable_encoder({"versions": versions}) +@console_ns.route("/workspaces/current/plugin/list/installations/ids") class PluginListInstallationsFromIdsApi(Resource): @setup_required @login_required @@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource): return jsonable_encoder({"plugins": plugins}) +@console_ns.route("/workspaces/current/plugin/icon") class PluginIconApi(Resource): @setup_required def get(self): @@ -108,6 +113,7 @@ class PluginIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) +@console_ns.route("/workspaces/current/plugin/upload/pkg") class PluginUploadFromPkgApi(Resource): @setup_required @login_required @@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/upload/github") class PluginUploadFromGithubApi(Resource): @setup_required @login_required @@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/upload/bundle") class PluginUploadFromBundleApi(Resource): @setup_required @login_required @@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/install/pkg") class PluginInstallFromPkgApi(Resource): @setup_required @login_required @@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/install/github") class 
PluginInstallFromGithubApi(Resource): @setup_required @login_required @@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/install/marketplace") class PluginInstallFromMarketplaceApi(Resource): @setup_required @login_required @@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource): return jsonable_encoder(response) +@console_ns.route("/workspaces/current/plugin/marketplace/pkg") class PluginFetchMarketplacePkgApi(Resource): @setup_required @login_required @@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/fetch-manifest") class PluginFetchManifestApi(Resource): @setup_required @login_required @@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks") class PluginFetchInstallTasksApi(Resource): @setup_required @login_required @@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks/") class PluginFetchInstallTaskApi(Resource): @setup_required @login_required @@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks//delete") class PluginDeleteInstallTaskApi(Resource): @setup_required @login_required @@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks/delete_all") class PluginDeleteAllInstallTaskItemsApi(Resource): @setup_required @login_required @@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/tasks//delete/") class PluginDeleteInstallTaskItemApi(Resource): @setup_required @login_required @@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource): raise ValueError(e) 
+@console_ns.route("/workspaces/current/plugin/upgrade/marketplace") class PluginUpgradeFromMarketplaceApi(Resource): @setup_required @login_required @@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/upgrade/github") class PluginUpgradeFromGithubApi(Resource): @setup_required @login_required @@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/uninstall") class PluginUninstallApi(Resource): @setup_required @login_required @@ -453,6 +474,7 @@ class PluginUninstallApi(Resource): raise ValueError(e) +@console_ns.route("/workspaces/current/plugin/permission/change") class PluginChangePermissionApi(Resource): @setup_required @login_required @@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource): return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} +@console_ns.route("/workspaces/current/plugin/permission/fetch") class PluginFetchPermissionApi(Resource): @setup_required @login_required @@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource): ) +@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options") class PluginFetchDynamicSelectOptionsApi(Resource): @setup_required @login_required @@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource): return jsonable_encoder({"options": options}) +@console_ns.route("/workspaces/current/plugin/preferences/change") class PluginChangePreferencesApi(Resource): @setup_required @login_required @@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource): return jsonable_encoder({"success": True}) +@console_ns.route("/workspaces/current/plugin/preferences/fetch") class PluginFetchPreferencesApi(Resource): @setup_required @login_required @@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource): return jsonable_encoder({"permission": permission_dict, "auto_upgrade": 
auto_upgrade_dict}) +@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude") class PluginAutoUpgradeExcludePluginApi(Resource): @setup_required @login_required @@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource): args = req.parse_args() return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) - - -api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") -api.add_resource(PluginListApi, "/workspaces/current/plugin/list") -api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions") -api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids") -api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon") -api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg") -api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github") -api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle") -api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg") -api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github") -api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace") -api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github") -api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace") -api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest") -api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks") -api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/") -api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks//delete") -api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all") 
-api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks//delete/") -api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") -api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg") - -api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") -api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") - -api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options") - -api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch") -api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change") -api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 8693d99e23..9285577f72 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -10,7 +10,7 @@ from flask_restx import ( from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.console import api +from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, enterprise_license_required, @@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool: return False +@console_ns.route("/workspaces/current/tool-providers") class ToolProviderListApi(Resource): @setup_required @login_required @@ -71,6 +72,7 @@ class ToolProviderListApi(Resource): return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) +@console_ns.route("/workspaces/current/tool-provider/builtin//tools") class ToolBuiltinProviderListToolsApi(Resource): @setup_required @login_required @@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource): 
) +@console_ns.route("/workspaces/current/tool-provider/builtin//info") class ToolBuiltinProviderInfoApi(Resource): @setup_required @login_required @@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) +@console_ns.route("/workspaces/current/tool-provider/builtin//delete") class ToolBuiltinProviderDeleteApi(Resource): @setup_required @login_required @@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//add") class ToolBuiltinProviderAddApi(Resource): @setup_required @login_required @@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//update") class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource): return result +@console_ns.route("/workspaces/current/tool-provider/builtin//credentials") class ToolBuiltinProviderGetCredentialsApi(Resource): @setup_required @login_required @@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//icon") class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): @@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) +@console_ns.route("/workspaces/current/tool-provider/api/add") class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/remote") class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): ) 
+@console_ns.route("/workspaces/current/tool-provider/api/tools") class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/update") class ToolApiProviderUpdateApi(Resource): @setup_required @login_required @@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/delete") class ToolApiProviderDeleteApi(Resource): @setup_required @login_required @@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/get") class ToolApiProviderGetApi(Resource): @setup_required @login_required @@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//credential/schema/") class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/schema") class ToolApiProviderSchemaApi(Resource): @setup_required @login_required @@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/api/test/pre") class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/workflow/create") class ToolWorkflowProviderCreateApi(Resource): @setup_required @login_required @@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/workflow/update") class ToolWorkflowProviderUpdateApi(Resource): @setup_required @login_required @@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource): ) 
+@console_ns.route("/workspaces/current/tool-provider/workflow/delete") class ToolWorkflowProviderDeleteApi(Resource): @setup_required @login_required @@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/workflow/get") class ToolWorkflowProviderGetApi(Resource): @setup_required @login_required @@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource): return jsonable_encoder(tool) +@console_ns.route("/workspaces/current/tool-provider/workflow/tools") class ToolWorkflowProviderListToolApi(Resource): @setup_required @login_required @@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource): ) +@console_ns.route("/workspaces/current/tools/builtin") class ToolBuiltinListApi(Resource): @setup_required @login_required @@ -624,6 +647,7 @@ class ToolBuiltinListApi(Resource): ) +@console_ns.route("/workspaces/current/tools/api") class ToolApiListApi(Resource): @setup_required @login_required @@ -642,6 +666,7 @@ class ToolApiListApi(Resource): ) +@console_ns.route("/workspaces/current/tools/workflow") class ToolWorkflowListApi(Resource): @setup_required @login_required @@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource): ) +@console_ns.route("/workspaces/current/tool-labels") class ToolLabelsApi(Resource): @setup_required @login_required @@ -672,6 +698,7 @@ class ToolLabelsApi(Resource): return jsonable_encoder(ToolLabelsService.list_tool_labels()) +@console_ns.route("/oauth/plugin//tool/authorization-url") class ToolPluginOAuthApi(Resource): @setup_required @login_required @@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource): return response +@console_ns.route("/oauth/plugin//tool/callback") class ToolOAuthCallback(Resource): @setup_required def get(self, provider): @@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") +@console_ns.route("/workspaces/current/tool-provider/builtin//default-credential") class 
ToolBuiltinProviderSetDefaultApi(Resource): @setup_required @login_required @@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//oauth/custom-client") class ToolOAuthCustomClient(Resource): @setup_required @login_required @@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//oauth/client-schema") class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): @setup_required @login_required @@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/builtin//credential/info") class ToolBuiltinProviderGetCredentialInfoApi(Resource): @setup_required @login_required @@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): ) +@console_ns.route("/workspaces/current/tool-provider/mcp") class ToolProviderMCPApi(Resource): @setup_required @login_required @@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource): return {"result": "success"} +@console_ns.route("/workspaces/current/tool-provider/mcp/auth") class ToolMCPAuthApi(Resource): @setup_required @login_required @@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource): raise ValueError(f"Failed to connect to MCP server: {e}") from e +@console_ns.route("/workspaces/current/tool-provider/mcp/tools/") class ToolMCPDetailApi(Resource): @setup_required @login_required @@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource): return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) +@console_ns.route("/workspaces/current/tools/mcp") class ToolMCPListAllApi(Resource): @setup_required @login_required @@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource): return [tool.to_dict() for tool in tools] +@console_ns.route("/workspaces/current/tool-provider/mcp/update/") class ToolMCPUpdateApi(Resource): @setup_required @login_required @@ -1014,6 +1051,7 @@ 
class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) +@console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): def get(self): parser = reqparse.RequestParser() @@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource): authorization_code = args["code"] handle_callback(state_key, authorization_code) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") - - -# tool provider -api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") - -# tool oauth -api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url") -api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback") -api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client") - -# builtin tool provider -api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools") -api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info") -api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add") -api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete") -api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update") -api.add_resource( - ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential" ) -api.add_resource( - ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info" ) -api.add_resource( - ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials" ) -api.add_resource( - ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>", ) -api.add_resource( - ToolBuiltinProviderGetOauthClientSchemaApi, - "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema", ) -api.add_resource(ToolBuiltinProviderIconApi, 
"/workspaces/current/tool-provider/builtin/<path:provider>/icon") - -# api tool provider -api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") -api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote") -api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools") -api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update") -api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete") -api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get") -api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema") -api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre") - -# workflow tool provider -api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create") -api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update") -api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete") -api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get") -api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools") - -# mcp tool provider -api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>") -api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp") -api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>") -api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth") -api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback") - -api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") -api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") -api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp") 
-api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") -api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 6bec70b5da..bc748ac3d2 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -14,7 +14,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.console import api +from controllers.console import console_ns from controllers.console.admin import admin_required from controllers.console.error import AccountNotLinkTenantError from controllers.console.wraps import ( @@ -65,6 +65,7 @@ tenants_fields = { workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} +@console_ns.route("/workspaces") class TenantListApi(Resource): @setup_required @login_required @@ -93,6 +94,7 @@ class TenantListApi(Resource): return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200 +@console_ns.route("/all-workspaces") class WorkspaceListApi(Resource): @setup_required @admin_required @@ -118,6 +120,8 @@ class WorkspaceListApi(Resource): }, 200 +@console_ns.route("/workspaces/current", endpoint="workspaces_current") +@console_ns.route("/info", endpoint="info") # Deprecated class TenantApi(Resource): @setup_required @login_required @@ -143,11 +147,10 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") - if not tenant: - raise ValueError("No tenant available") return WorkspaceService.get_tenant_info(tenant), 200 +@console_ns.route("/workspaces/switch") class SwitchWorkspaceApi(Resource): @setup_required @login_required @@ -172,6 +175,7 @@ class SwitchWorkspaceApi(Resource): return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} +@console_ns.route("/workspaces/custom-config") 
class CustomConfigWorkspaceApi(Resource): @setup_required @login_required @@ -202,6 +206,7 @@ class CustomConfigWorkspaceApi(Resource): return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} +@console_ns.route("/workspaces/custom-config/webapp-logo/upload") class WebappLogoWorkspaceApi(Resource): @setup_required @login_required @@ -242,6 +247,7 @@ class WebappLogoWorkspaceApi(Resource): return {"id": upload_file.id}, 201 +@console_ns.route("/workspaces/info") class WorkspaceInfoApi(Resource): @setup_required @login_required @@ -261,13 +267,3 @@ class WorkspaceInfoApi(Resource): db.session.commit() return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} - - -api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants -api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants -api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info -api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated -api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant -api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") -api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") -api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index b683aa3160..1f588bedce 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -25,8 +25,8 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: As a result, it could only be considered as an end user id. 
""" if not user_id: - user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value - is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID + is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID try: with Session(db.engine) as session: user_model = None @@ -85,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None): raise ValueError("tenant_id is required") if not user_id: - user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID try: tenant_model = ( @@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo raise ValueError("invalid json") try: - payload = payload_type(**data) + payload = payload_type.model_validate(data) except Exception as e: raise ValueError(f"invalid payload: {str(e)}") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 961b96db91..92bbb76f0f 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -280,7 +280,7 @@ class DatasetListApi(DatasetApiResource): external_knowledge_id=args["external_knowledge_id"], embedding_model_provider=args["embedding_model_provider"], embedding_model_name=args["embedding_model"], - retrieval_model=RetrievalModel(**args["retrieval_model"]) + retrieval_model=RetrievalModel.model_validate(args["retrieval_model"]) if args["retrieval_model"] is not None else None, ) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c1122acd7b..961a338bc5 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -136,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource): "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } 
args["data_source"] = data_source - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) # validate args DocumentService.document_create_args_validate(knowledge_config) @@ -221,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) DocumentService.document_create_args_validate(knowledge_config) try: @@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource): } args["data_source"] = data_source # validate args - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) DocumentService.document_create_args_validate(knowledge_config) dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None @@ -426,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): # validate args args["original_document_id"] = str(document_id) - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) DocumentService.document_create_args_validate(knowledge_config) try: diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index e01659dc68..51420fdd5f 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): def post(self, tenant_id, dataset_id): """Create metadata for a dataset.""" args = metadata_create_parser.parse_args() - metadata_args = MetadataArgs(**args) + metadata_args = MetadataArgs.model_validate(args) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, 
current_user) args = document_metadata_parser.parse_args() - metadata_args = MetadataOperationData(**args) + metadata_args = MetadataOperationData.model_validate(args) MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index f05325d711..13ef8abc2d 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource): parser.add_argument("is_published", type=bool, required=True, location="json") args: ParseResult = parser.parse_args() - datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args) + datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args) assert isinstance(current_user, Account) rag_pipeline_service: RagPipelineService = RagPipelineService() pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index a22155b07a..d674c7467d 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource): args = segment_update_parser.parse_args() updated_segment = SegmentService.update_segment( - SegmentUpdateArgs(**args["segment"]), segment, document, dataset + SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset ) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index ee8e1d105b..2c9be4e887 100644 --- a/api/controllers/service_api/wraps.py +++ 
b/api/controllers/service_api/wraps.py @@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = Create or update session terminal based on user ID. """ if not user_id: - user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID with Session(db.engine, expire_on_commit=False) as session: end_user = ( @@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = tenant_id=app_model.tenant_id, app_id=app_model.id, type="service_api", - is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, + is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID, session_id=user_id, ) session.add(end_user) diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 6f7105a724..7190f06426 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: end_user_id = enterprise_user_decoded.get("end_user_id") session_id = enterprise_user_decoded.get("session_id") user_auth_type = enterprise_user_decoded.get("auth_type") + exchanged_token_expires_unix = enterprise_user_decoded.get("exp") + if not user_auth_type: raise Unauthorized("Missing auth_type in the token.") @@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: ) db.session.add(end_user) db.session.commit() - exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) - exp = int(exp_dt.timestamp()) + + exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp()) + if exchanged_token_expires_unix: + exp = int(exchanged_token_expires_unix) + payload = { "iss": site.id, "sub": "Web API Passport", diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py 
b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index eab26e5af9..c1f336fdde 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -40,7 +40,7 @@ class AgentConfigManager: "credential_id": tool.get("credential_id", None), } - agent_tools.append(AgentToolEntity(**agent_tool_properties)) + agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties)) if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { "react_router", diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 3564cc175b..aacafb2dad 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -197,12 +197,12 @@ class DatasetConfigManager: # strategy if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER has_datasets = False if config.get("agent_mode", {}).get("strategy") in { - PlanningStrategy.ROUTER.value, - PlanningStrategy.REACT_ROUTER.value, + PlanningStrategy.ROUTER, + PlanningStrategy.REACT_ROUTER, }: for tool in config.get("agent_mode", {}).get("tools", []): key = list(tool.keys())[0] diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 5b5eefe315..7cd5fe75d5 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -68,7 +68,7 @@ class ModelConfigConverter: # get model mode model_mode = model_config.mode if not model_mode: - model_mode = LLMMode.CHAT.value + model_mode = LLMMode.CHAT if model_schema and 
model_schema.model_properties.get(ModelPropertyKey.MODE): model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index ec4f6074ab..21614c010c 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -100,7 +100,7 @@ class PromptTemplateConfigManager: if config["model"]["mode"] not in model_mode_vals: raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") - if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value: + if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION: user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] @@ -110,7 +110,7 @@ class PromptTemplateConfigManager: if not assistant_prefix: config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant" - if config["model"]["mode"] == ModelMode.CHAT.value: + if config["model"]["mode"] == ModelMode.CHAT: prompt_list = config["chat_prompt_config"]["prompt"] if len(prompt_list) > 10: diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 9ce841f432..801619ddbc 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): raise ValueError("enabled in agent_mode must be of boolean type") if not agent_mode.get("strategy"): - agent_mode["strategy"] = PlanningStrategy.ROUTER.value + agent_mode["strategy"] = PlanningStrategy.ROUTER if agent_mode["strategy"] not in 
[member.value for member in list(PlanningStrategy.__members__.values())]: raise ValueError("strategy in agent_mode must be in the specified strategy list") diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 388bed5255..759398b556 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner): # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: # check LLM mode - if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT: runner_cls = CotChatAgentRunner - elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value: + elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION: runner_cls = CotCompletionAgentRunner else: raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}") diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index e7db3bc41b..61ac040c05 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -61,9 +61,6 @@ class AppRunner: if model_context_tokens is None: return -1 - if max_tokens is None: - max_tokens = 0 - prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) if prompt_tokens + max_tokens > model_context_tokens: diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 145f629c4d..a8a7dde2b4 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner): rag_pipeline_variables = [] if workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables: - rag_pipeline_variable = RAGPipelineVariable(**v) + 
rag_pipeline_variable = RAGPipelineVariable.model_validate(v) if ( rag_pipeline_variable.belong_to_node_id in (self.application_generate_entity.start_node_id, "shared") @@ -229,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner): workflow_id=workflow.id, graph_config=graph_config, user_id=self.application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 564daba86d..68eb455d26 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -100,8 +100,8 @@ class WorkflowBasedAppRunner: workflow_id=workflow_id, graph_config=graph_config, user_id=user_id, - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) @@ -244,8 +244,8 @@ class WorkflowBasedAppRunner: workflow_id=workflow.id, graph_config=graph_config, user_id="", - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index cdefcc4506..1179537570 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel): for datasource in datasources: if datasource.get("parameters"): for parameter in datasource.get("parameters"): - if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value: + if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES: parameter["type"] = "files" # ------------- diff --git a/api/core/datasource/entities/common_entities.py 
b/api/core/datasource/entities/common_entities.py index ac36d83ae3..3c64632dbb 100644 --- a/api/core/datasource/entities/common_entities.py +++ b/api/core/datasource/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class I18nObject(BaseModel): @@ -11,11 +11,12 @@ class I18nObject(BaseModel): pt_BR: str | None = Field(default=None) ja_JP: str | None = Field(default=None) - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): self.zh_Hans = self.zh_Hans or self.en_US self.pt_BR = self.pt_BR or self.en_US self.ja_JP = self.ja_JP or self.en_US + return self def to_dict(self) -> dict: return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index ac4f51ac75..7f503b963f 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -54,16 +54,16 @@ class DatasourceParameter(PluginParameter): removes TOOLS_SELECTOR from PluginParameterType """ - STRING = PluginParameterType.STRING.value - NUMBER = PluginParameterType.NUMBER.value - BOOLEAN = PluginParameterType.BOOLEAN.value - SELECT = PluginParameterType.SELECT.value - SECRET_INPUT = PluginParameterType.SECRET_INPUT.value - FILE = PluginParameterType.FILE.value - FILES = PluginParameterType.FILES.value + STRING = PluginParameterType.STRING + NUMBER = PluginParameterType.NUMBER + BOOLEAN = PluginParameterType.BOOLEAN + SELECT = PluginParameterType.SELECT + SECRET_INPUT = PluginParameterType.SECRET_INPUT + FILE = PluginParameterType.FILE + FILES = PluginParameterType.FILES # deprecated, should not use. 
- SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value + SYSTEM_FILES = PluginParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 111de89178..bc19afb52a 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] for configurate_method in self.provider.configurate_methods: @@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel): and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) + return self def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ @@ -207,7 +207,7 @@ class ProviderConfiguration(BaseModel): """ stmt = select(Provider).where( Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, + Provider.provider_type == ProviderType.CUSTOM, Provider.provider_name.in_(self._get_provider_names()), ) @@ -458,7 +458,7 @@ class ProviderConfiguration(BaseModel): provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - provider_type=ProviderType.CUSTOM.value, + 
provider_type=ProviderType.CUSTOM, is_valid=True, credential_id=new_record.id, ) @@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel): """ secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: - if credential_form_schema.type.value == FormType.SECRET_INPUT.value: + if credential_form_schema.type.value == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.variable) return secret_input_form_variables diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index fab9ae44e9..f9e6099049 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,13 +1,13 @@ from typing import cast -import requests +import httpx from configs import dify_config from models.api_based_extension import APIBasedExtensionPoint class APIBasedExtensionRequestor: - timeout: tuple[int, int] = (5, 60) + timeout: httpx.Timeout = httpx.Timeout(60.0, connect=5.0) """timeout for request connect and read""" def __init__(self, api_endpoint: str, api_key: str): @@ -27,25 +27,23 @@ class APIBasedExtensionRequestor: url = self.api_endpoint try: - # proxy support for security - proxies = None + mounts: dict[str, httpx.BaseTransport] | None = None if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: - proxies = { - "http": dify_config.SSRF_PROXY_HTTP_URL, - "https": dify_config.SSRF_PROXY_HTTPS_URL, + mounts = { + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL), + "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL), } - response = requests.request( - method="POST", - url=url, - json={"point": point.value, "params": params}, - headers=headers, - timeout=self.timeout, - proxies=proxies, - ) - except requests.Timeout: + with httpx.Client(mounts=mounts, timeout=self.timeout) as client: + response = client.request( + method="POST", + url=url, + 
json={"point": point.value, "params": params}, + headers=headers, + ) + except httpx.TimeoutException: raise ValueError("request timeout") - except requests.ConnectionError: + except httpx.RequestError: raise ValueError("request connection error") if response.status_code != 200: diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 0c1d03dc13..f92278f9e2 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -131,7 +131,7 @@ class CodeExecutor: if (code := response_data.get("code")) != 0: raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}") - response_code = CodeExecutionResponse(**response_data) + response_code = CodeExecutionResponse.model_validate(response_data) if response_code.data.error: raise CodeExecutionError(response_code.data.error) diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index 10f304c087..bddb864a95 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version}) response.raise_for_status() - return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] + return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]] def batch_fetch_plugin_manifests_ignore_deserialization_error( @@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error( result: list[MarketplacePluginDeclaration] = [] for plugin in response.json()["data"]["plugins"]: try: - result.append(MarketplacePluginDeclaration(**plugin)) + result.append(MarketplacePluginDeclaration.model_validate(plugin)) except Exception: pass diff --git a/api/core/indexing_runner.py 
b/api/core/indexing_runner.py index ee37024260..7822ed4268 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -343,7 +343,7 @@ class IndexingRunner: if file_detail: extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=dataset_document.doc_form, ) @@ -356,15 +356,17 @@ class IndexingRunner: ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": data_source_info["credential_id"], - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "document": dataset_document, - "tenant_id": dataset_document.tenant_id, - }, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info["credential_id"], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id, + } + ), document_model=dataset_document.doc_form, ) text_docs = 
index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) @@ -377,15 +379,17 @@ class IndexingRunner: ): raise ValueError("no website import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "tenant_id": dataset_document.tenant_id, - "url": data_source_info["url"], - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - }, + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 1e302b7668..686529c3ca 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -224,8 +224,8 @@ def _handle_native_json_schema( # Set appropriate response format if required by the model for rule in rules: - if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: - model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value + if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA return model_parameters @@ -239,10 +239,10 @@ def _set_response_format(model_parameters: dict, rules: list): """ for rule in rules: if rule.name == "response_format": - if ResponseFormat.JSON.value in rule.options: - 
model_parameters["response_format"] = ResponseFormat.JSON.value - elif ResponseFormat.JSON_OBJECT.value in rule.options: - model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value + if ResponseFormat.JSON in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON + elif ResponseFormat.JSON_OBJECT in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_OBJECT def _handle_prompt_based_schema( diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 5817416ba4..fa1d309134 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -294,7 +294,7 @@ class ClientSession( method="completion/complete", params=types.CompleteRequestParams( ref=ref, - argument=types.CompletionArgument(**argument), + argument=types.CompletionArgument.model_validate(argument), ), ) ), diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index c7353de5af..b673efae22 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, model_validator class I18nObject(BaseModel): @@ -9,7 +9,8 @@ class I18nObject(BaseModel): zh_Hans: str | None = None en_US: str - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): if not self.zh_Hans: self.zh_Hans = self.en_US + return self diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index 2ccc9e0eae..831fb9d4db 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from enum import Enum, StrEnum, auto -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import 
BaseModel, ConfigDict, Field, field_validator, model_validator from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, ModelType @@ -46,10 +46,11 @@ class FormOption(BaseModel): value: str show_on: list[FormShowOnObject] = [] - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): if not self.label: self.label = I18nObject(en_US=self.value) + return self class CredentialFormSchema(BaseModel): diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index e070c17abd..e1afc41bee 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -269,17 +269,17 @@ class ModelProviderFactory: } if model_type == ModelType.LLM: - return LargeLanguageModel(**init_params) # type: ignore + return LargeLanguageModel.model_validate(init_params) elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel(**init_params) # type: ignore + return TextEmbeddingModel.model_validate(init_params) elif model_type == ModelType.RERANK: - return RerankModel(**init_params) # type: ignore + return RerankModel.model_validate(init_params) elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel(**init_params) # type: ignore + return Speech2TextModel.model_validate(init_params) elif model_type == ModelType.MODERATION: - return ModerationModel(**init_params) # type: ignore + return ModerationModel.model_validate(init_params) elif model_type == ModelType.TTS: - return TTSModel(**init_params) # type: ignore + return TTSModel.model_validate(init_params) def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: """ diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 573f4ec2a7..2d72b17a04 100644 --- 
a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -51,7 +51,7 @@ class ApiModeration(Moderation): params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) - return ModerationInputsResult(**result) + return ModerationInputsResult.model_validate(result) return ModerationInputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response @@ -67,7 +67,7 @@ class ApiModeration(Moderation): params = ModerationOutputParams(app_id=self.app_id, text=text) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) - return ModerationOutputsResult(**result) + return ModerationOutputsResult.model_validate(result) return ModerationOutputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 1497bc1863..03d2d75372 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -213,9 +213,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_metadata.update(json.loads(node_execution.execution_metadata)) # Determine the correct span kind based on node type - span_kind = OpenInferenceSpanKindValues.CHAIN.value + span_kind = OpenInferenceSpanKindValues.CHAIN if node_execution.node_type == "llm": - span_kind = OpenInferenceSpanKindValues.LLM.value + span_kind = OpenInferenceSpanKindValues.LLM provider = process_data.get("model_provider") model = process_data.get("model_name") if provider: @@ -230,18 +230,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0) node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0) elif 
node_execution.node_type == "dataset_retrieval": - span_kind = OpenInferenceSpanKindValues.RETRIEVER.value + span_kind = OpenInferenceSpanKindValues.RETRIEVER elif node_execution.node_type == "tool": - span_kind = OpenInferenceSpanKindValues.TOOL.value + span_kind = OpenInferenceSpanKindValues.TOOL else: - span_kind = OpenInferenceSpanKindValues.CHAIN.value + span_kind = OpenInferenceSpanKindValues.CHAIN node_span = self.tracer.start_span( name=node_execution.node_type, attributes={ SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}", SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}", - SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind, + SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False), SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 931bed78d4..92e6b8ea60 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -73,7 +73,7 @@ class LangFuseDataTrace(BaseTraceInstance): if trace_info.message_id: trace_id = trace_info.trace_id or trace_info.message_id - name = TraceTaskName.MESSAGE_TRACE.value + name = TraceTaskName.MESSAGE_TRACE trace_data = LangfuseTrace( id=trace_id, user_id=user_id, @@ -88,7 +88,7 @@ class LangFuseDataTrace(BaseTraceInstance): self.add_trace(langfuse_trace_data=trace_data) workflow_span_data = LangfuseSpan( id=trace_info.workflow_run_id, - name=TraceTaskName.WORKFLOW_TRACE.value, + name=TraceTaskName.WORKFLOW_TRACE, input=dict(trace_info.workflow_run_inputs), output=dict(trace_info.workflow_run_outputs), trace_id=trace_id, @@ -103,7 +103,7 @@ class LangFuseDataTrace(BaseTraceInstance): trace_data = LangfuseTrace( id=trace_id, user_id=user_id, - name=TraceTaskName.WORKFLOW_TRACE.value, + name=TraceTaskName.WORKFLOW_TRACE, input=dict(trace_info.workflow_run_inputs), 
output=dict(trace_info.workflow_run_outputs), metadata=metadata, @@ -253,7 +253,7 @@ class LangFuseDataTrace(BaseTraceInstance): trace_data = LangfuseTrace( id=trace_id, user_id=user_id, - name=TraceTaskName.MESSAGE_TRACE.value, + name=TraceTaskName.MESSAGE_TRACE, input={ "message": trace_info.inputs, "files": file_list, @@ -303,7 +303,7 @@ class LangFuseDataTrace(BaseTraceInstance): if trace_info.message_data is None: return span_data = LangfuseSpan( - name=TraceTaskName.MODERATION_TRACE.value, + name=TraceTaskName.MODERATION_TRACE, input=trace_info.inputs, output={ "action": trace_info.action, @@ -331,7 +331,7 @@ class LangFuseDataTrace(BaseTraceInstance): ) generation_data = LangfuseGeneration( - name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, + name=TraceTaskName.SUGGESTED_QUESTION_TRACE, input=trace_info.inputs, output=str(trace_info.suggested_question), trace_id=trace_info.trace_id or trace_info.message_id, @@ -349,7 +349,7 @@ class LangFuseDataTrace(BaseTraceInstance): if trace_info.message_data is None: return dataset_retrieval_span_data = LangfuseSpan( - name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, + name=TraceTaskName.DATASET_RETRIEVAL_TRACE, input=trace_info.inputs, output={"documents": trace_info.documents}, trace_id=trace_info.trace_id or trace_info.message_id, @@ -377,7 +377,7 @@ class LangFuseDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): name_generation_trace_data = LangfuseTrace( - name=TraceTaskName.GENERATE_NAME_TRACE.value, + name=TraceTaskName.GENERATE_NAME_TRACE, input=trace_info.inputs, output=trace_info.outputs, user_id=trace_info.tenant_id, @@ -388,7 +388,7 @@ class LangFuseDataTrace(BaseTraceInstance): self.add_trace(langfuse_trace_data=name_generation_trace_data) name_generation_span_data = LangfuseSpan( - name=TraceTaskName.GENERATE_NAME_TRACE.value, + name=TraceTaskName.GENERATE_NAME_TRACE, input=trace_info.inputs, output=trace_info.outputs, trace_id=trace_info.conversation_id, 
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 24a43e1cd8..8b8117b24c 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -81,7 +81,7 @@ class LangSmithDataTrace(BaseTraceInstance): if trace_info.message_id: message_run = LangSmithRunModel( id=trace_info.message_id, - name=TraceTaskName.MESSAGE_TRACE.value, + name=TraceTaskName.MESSAGE_TRACE, inputs=dict(trace_info.workflow_run_inputs), outputs=dict(trace_info.workflow_run_outputs), run_type=LangSmithRunType.chain, @@ -110,7 +110,7 @@ class LangSmithDataTrace(BaseTraceInstance): file_list=trace_info.file_list, total_tokens=trace_info.total_tokens, id=trace_info.workflow_run_id, - name=TraceTaskName.WORKFLOW_TRACE.value, + name=TraceTaskName.WORKFLOW_TRACE, inputs=dict(trace_info.workflow_run_inputs), run_type=LangSmithRunType.tool, start_time=trace_info.workflow_data.created_at, @@ -271,7 +271,7 @@ class LangSmithDataTrace(BaseTraceInstance): output_tokens=trace_info.answer_tokens, total_tokens=trace_info.total_tokens, id=message_id, - name=TraceTaskName.MESSAGE_TRACE.value, + name=TraceTaskName.MESSAGE_TRACE, inputs=trace_info.inputs, run_type=LangSmithRunType.chain, start_time=trace_info.start_time, @@ -327,7 +327,7 @@ class LangSmithDataTrace(BaseTraceInstance): if trace_info.message_data is None: return langsmith_run = LangSmithRunModel( - name=TraceTaskName.MODERATION_TRACE.value, + name=TraceTaskName.MODERATION_TRACE, inputs=trace_info.inputs, outputs={ "action": trace_info.action, @@ -362,7 +362,7 @@ class LangSmithDataTrace(BaseTraceInstance): if message_data is None: return suggested_question_run = LangSmithRunModel( - name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, + name=TraceTaskName.SUGGESTED_QUESTION_TRACE, inputs=trace_info.inputs, outputs=trace_info.suggested_question, run_type=LangSmithRunType.tool, @@ -391,7 +391,7 @@ class 
LangSmithDataTrace(BaseTraceInstance): if trace_info.message_data is None: return dataset_retrieval_run = LangSmithRunModel( - name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, + name=TraceTaskName.DATASET_RETRIEVAL_TRACE, inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, run_type=LangSmithRunType.retriever, @@ -447,7 +447,7 @@ class LangSmithDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): name_run = LangSmithRunModel( - name=TraceTaskName.GENERATE_NAME_TRACE.value, + name=TraceTaskName.GENERATE_NAME_TRACE, inputs=trace_info.inputs, outputs=trace_info.outputs, run_type=LangSmithRunType.tool, diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 8fa92f9fcd..8050c59db9 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -108,7 +108,7 @@ class OpikDataTrace(BaseTraceInstance): trace_data = { "id": opik_trace_id, - "name": TraceTaskName.MESSAGE_TRACE.value, + "name": TraceTaskName.MESSAGE_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": workflow_metadata, @@ -125,7 +125,7 @@ class OpikDataTrace(BaseTraceInstance): "id": root_span_id, "parent_span_id": None, "trace_id": opik_trace_id, - "name": TraceTaskName.WORKFLOW_TRACE.value, + "name": TraceTaskName.WORKFLOW_TRACE, "input": wrap_dict("input", trace_info.workflow_run_inputs), "output": wrap_dict("output", trace_info.workflow_run_outputs), "start_time": trace_info.start_time, @@ -138,7 +138,7 @@ class OpikDataTrace(BaseTraceInstance): else: trace_data = { "id": opik_trace_id, - "name": TraceTaskName.MESSAGE_TRACE.value, + "name": TraceTaskName.MESSAGE_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": workflow_metadata, @@ -290,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance): trace_data = { "id": prepare_opik_uuid(trace_info.start_time, dify_trace_id), - "name": 
TraceTaskName.MESSAGE_TRACE.value, + "name": TraceTaskName.MESSAGE_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": wrap_metadata(metadata), @@ -329,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance): span_data = { "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), - "name": TraceTaskName.MODERATION_TRACE.value, + "name": TraceTaskName.MODERATION_TRACE, "type": "tool", "start_time": start_time, "end_time": trace_info.end_time or trace_info.message_data.updated_at, @@ -355,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance): span_data = { "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), - "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value, + "name": TraceTaskName.SUGGESTED_QUESTION_TRACE, "type": "tool", "start_time": start_time, "end_time": trace_info.end_time or message_data.updated_at, @@ -375,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance): span_data = { "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), - "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value, + "name": TraceTaskName.DATASET_RETRIEVAL_TRACE, "type": "tool", "start_time": start_time, "end_time": trace_info.end_time or trace_info.message_data.updated_at, @@ -405,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): trace_data = { "id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id), - "name": TraceTaskName.GENERATE_NAME_TRACE.value, + "name": TraceTaskName.GENERATE_NAME_TRACE, "start_time": trace_info.start_time, "end_time": trace_info.end_time, "metadata": wrap_metadata(trace_info.metadata), @@ -420,7 +420,7 @@ class OpikDataTrace(BaseTraceInstance): span_data = { "trace_id": trace.id, - "name": TraceTaskName.GENERATE_NAME_TRACE.value, + "name": TraceTaskName.GENERATE_NAME_TRACE, "start_time": trace_info.start_time, "end_time": 
trace_info.end_time, "metadata": wrap_metadata(trace_info.metadata), diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 185bdd8179..9b3d7a8192 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -104,7 +104,7 @@ class WeaveDataTrace(BaseTraceInstance): message_run = WeaveTraceModel( id=trace_info.message_id, - op=str(TraceTaskName.MESSAGE_TRACE.value), + op=str(TraceTaskName.MESSAGE_TRACE), inputs=dict(trace_info.workflow_run_inputs), outputs=dict(trace_info.workflow_run_outputs), total_tokens=trace_info.total_tokens, @@ -126,7 +126,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=trace_info.file_list, total_tokens=trace_info.total_tokens, id=trace_info.workflow_run_id, - op=str(TraceTaskName.WORKFLOW_TRACE.value), + op=str(TraceTaskName.WORKFLOW_TRACE), inputs=dict(trace_info.workflow_run_inputs), outputs=dict(trace_info.workflow_run_outputs), attributes=workflow_attributes, @@ -253,7 +253,7 @@ class WeaveDataTrace(BaseTraceInstance): message_run = WeaveTraceModel( id=trace_id, - op=str(TraceTaskName.MESSAGE_TRACE.value), + op=str(TraceTaskName.MESSAGE_TRACE), input_tokens=trace_info.message_tokens, output_tokens=trace_info.answer_tokens, total_tokens=trace_info.total_tokens, @@ -300,7 +300,7 @@ class WeaveDataTrace(BaseTraceInstance): moderation_run = WeaveTraceModel( id=str(uuid.uuid4()), - op=str(TraceTaskName.MODERATION_TRACE.value), + op=str(TraceTaskName.MODERATION_TRACE), inputs=trace_info.inputs, outputs={ "action": trace_info.action, @@ -330,7 +330,7 @@ class WeaveDataTrace(BaseTraceInstance): suggested_question_run = WeaveTraceModel( id=str(uuid.uuid4()), - op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value), + op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE), inputs=trace_info.inputs, outputs=trace_info.suggested_question, attributes=attributes, @@ -355,7 +355,7 @@ class WeaveDataTrace(BaseTraceInstance): dataset_retrieval_run = WeaveTraceModel( 
id=str(uuid.uuid4()), - op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value), + op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE), inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, attributes=attributes, @@ -397,7 +397,7 @@ class WeaveDataTrace(BaseTraceInstance): name_run = WeaveTraceModel( id=str(uuid.uuid4()), - op=str(TraceTaskName.GENERATE_NAME_TRACE.value), + op=str(TraceTaskName.GENERATE_NAME_TRACE), inputs=trace_info.inputs, outputs=trace_info.outputs, attributes=attributes, diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 1d6d21cff7..9fbcbf55b4 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): instruction=instruction, # instruct with variables are not supported ) node_data_dict = node_data.model_dump() - node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value + node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR execution = workflow_service.run_free_workflow_node( node_data_dict, tenant_id=tenant_id, diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 10f37f75f8..d5df85730b 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -83,16 +83,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel): raise ValueError("prompt_messages must be a list") for i in range(len(v)): - if v[i]["role"] == PromptMessageRole.USER.value: - v[i] = UserPromptMessage(**v[i]) - elif v[i]["role"] == PromptMessageRole.ASSISTANT.value: - v[i] = AssistantPromptMessage(**v[i]) - elif v[i]["role"] == PromptMessageRole.SYSTEM.value: - v[i] = SystemPromptMessage(**v[i]) - elif v[i]["role"] == PromptMessageRole.TOOL.value: - v[i] = ToolPromptMessage(**v[i]) + if v[i]["role"] == PromptMessageRole.USER: + v[i] = UserPromptMessage.model_validate(v[i]) + elif v[i]["role"] == 
PromptMessageRole.ASSISTANT: + v[i] = AssistantPromptMessage.model_validate(v[i]) + elif v[i]["role"] == PromptMessageRole.SYSTEM: + v[i] = SystemPromptMessage.model_validate(v[i]) + elif v[i]["role"] == PromptMessageRole.TOOL: + v[i] = ToolPromptMessage.model_validate(v[i]) else: - v[i] = PromptMessage(**v[i]) + v[i] = PromptMessage.model_validate(v[i]) return v diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 8e3df4da2c..c791b35161 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -2,11 +2,10 @@ import inspect import json import logging from collections.abc import Callable, Generator -from typing import TypeVar +from typing import Any, TypeVar -import requests +import httpx from pydantic import BaseModel -from requests.exceptions import HTTPError from yarl import URL from configs import dify_config @@ -47,29 +46,56 @@ class BasePluginClient: data: bytes | dict | str | None = None, params: dict | None = None, files: dict | None = None, - stream: bool = False, - ) -> requests.Response: + ) -> httpx.Response: """ Make a request to the plugin daemon inner API. 
""" - url = plugin_daemon_inner_api_baseurl / path - headers = headers or {} - headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY - headers["Accept-Encoding"] = "gzip, deflate, br" + url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files) - if headers.get("Content-Type") == "application/json" and isinstance(data, dict): - data = json.dumps(data) + request_kwargs: dict[str, Any] = { + "method": method, + "url": url, + "headers": headers, + "params": params, + "files": files, + } + if isinstance(prepared_data, dict): + request_kwargs["data"] = prepared_data + elif prepared_data is not None: + request_kwargs["content"] = prepared_data try: - response = requests.request( - method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files - ) - except requests.ConnectionError: + response = httpx.request(**request_kwargs) + except httpx.RequestError: logger.exception("Request to Plugin Daemon Service failed") raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") return response + def _prepare_request( + self, + path: str, + headers: dict | None, + data: bytes | dict | str | None, + params: dict | None, + files: dict | None, + ) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]: + url = plugin_daemon_inner_api_baseurl / path + prepared_headers = dict(headers or {}) + prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY + prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br") + + prepared_data: bytes | dict | str | None = ( + data if isinstance(data, (bytes, str, dict)) or data is None else None + ) + if isinstance(data, dict): + if prepared_headers.get("Content-Type") == "application/json": + prepared_data = json.dumps(data) + else: + prepared_data = data + + return str(url), prepared_headers, prepared_data, params, files + def _stream_request( self, method: str, @@ -78,23 +104,44 @@ class BasePluginClient: 
headers: dict | None = None, data: bytes | dict | None = None, files: dict | None = None, - ) -> Generator[bytes, None, None]: + ) -> Generator[str, None, None]: """ Make a stream request to the plugin daemon inner API """ - response = self._request(method, path, headers, data, params, files, stream=True) - for line in response.iter_lines(chunk_size=1024 * 8): - line = line.decode("utf-8").strip() - if line.startswith("data:"): - line = line[5:].strip() - if line: - yield line + url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files) + + stream_kwargs: dict[str, Any] = { + "method": method, + "url": url, + "headers": headers, + "params": params, + "files": files, + } + if isinstance(prepared_data, dict): + stream_kwargs["data"] = prepared_data + elif prepared_data is not None: + stream_kwargs["content"] = prepared_data + + try: + with httpx.stream(**stream_kwargs) as response: + for raw_line in response.iter_lines(): + if raw_line is None: + continue + line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line + line = line.strip() + if line.startswith("data:"): + line = line[5:].strip() + if line: + yield line + except httpx.RequestError: + logger.exception("Stream request to Plugin Daemon Service failed") + raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") def _stream_request_with_model( self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -104,13 +151,13 @@ class BasePluginClient: Make a stream request to the plugin daemon inner API and yield the response as a model. 
""" for line in self._stream_request(method, path, params, headers, data, files): - yield type(**json.loads(line)) # type: ignore + yield type_(**json.loads(line)) # type: ignore def _request_with_model( self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | None = None, params: dict | None = None, @@ -120,13 +167,13 @@ class BasePluginClient: Make a request to the plugin daemon inner API and return the response as a model. """ response = self._request(method, path, headers, data, params, files) - return type(**response.json()) # type: ignore + return type_(**response.json()) # type: ignore def _request_with_plugin_daemon_response( self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -139,23 +186,23 @@ class BasePluginClient: try: response = self._request(method, path, headers, data, params, files) response.raise_for_status() - except HTTPError as e: - msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}" - logger.exception(msg) + except httpx.HTTPStatusError as e: + logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path) raise e except Exception as e: msg = f"Failed to request plugin daemon, url: {path}" - logger.exception(msg) + logger.exception("Failed to request plugin daemon, url: %s", path) raise ValueError(msg) from e try: json_response = response.json() if transformer: json_response = transformer(json_response) - rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore + # https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why + rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore except Exception: msg = ( - f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}]," + f"Failed to parse response from 
plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}]," f" url: {path}" ) logger.exception(msg) @@ -163,7 +210,7 @@ class BasePluginClient: if rep.code != 0: try: - error = PluginDaemonError(**json.loads(rep.message)) + error = PluginDaemonError.model_validate(json.loads(rep.message)) except Exception: raise ValueError(f"{rep.message}, code: {rep.code}") @@ -178,7 +225,7 @@ class BasePluginClient: self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -189,7 +236,7 @@ class BasePluginClient: """ for line in self._stream_request(method, path, params, headers, data, files): try: - rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore + rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore except (ValueError, TypeError): # TODO modify this when line_data has code and message try: @@ -204,7 +251,7 @@ class BasePluginClient: if rep.code != 0: if rep.code == -500: try: - error = PluginDaemonError(**json.loads(rep.message)) + error = PluginDaemonError.model_validate(json.loads(rep.message)) except Exception: raise PluginDaemonInnerError(code=rep.code, message=rep.message) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 84087f8104..ce1ef71494 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient): params={"page": 1, "page_size": 256}, transformer=transformer, ) - local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate( + self._get_local_file_datasource_provider() + ) for provider in response: ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) @@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient): 
Fetch datasource provider for the given tenant and plugin. """ if provider_id == "langgenius/file/file": - return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider()) tool_provider_id = DatasourceProviderID(provider_id) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 153da142f4..5dfc3c212e 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/llm/invoke", - type=LLMResultChunk, + type_=LLMResultChunk, data=jsonable_encoder( { "user_id": user_id, @@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", - type=PluginLLMNumTokensResponse, + type_=PluginLLMNumTokensResponse, data=jsonable_encoder( { "user_id": user_id, @@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", - type=TextEmbeddingResult, + type_=TextEmbeddingResult, data=jsonable_encoder( { "user_id": user_id, @@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", - type=PluginTextEmbeddingNumTokensResponse, + type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( { "user_id": user_id, @@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/rerank/invoke", - type=RerankResult, + type_=RerankResult, 
data=jsonable_encoder( { "user_id": user_id, @@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/tts/invoke", - type=PluginStringResultResponse, + type_=PluginStringResultResponse, data=jsonable_encoder( { "user_id": user_id, @@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/tts/model/voices", - type=PluginVoicesResponse, + type_=PluginVoicesResponse, data=jsonable_encoder( { "user_id": user_id, @@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", - type=PluginStringResultResponse, + type_=PluginStringResultResponse, data=jsonable_encoder( { "user_id": user_id, @@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/moderation/invoke", - type=PluginBasicBooleanResponse, + type_=PluginBasicBooleanResponse, data=jsonable_encoder( { "user_id": user_id, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 499d39bd5d..7bc9830ac3 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -610,7 +610,7 @@ class ProviderManager: provider_quota_to_provider_record_dict = {} for provider_record in provider_records: - if provider_record.provider_type != ProviderType.SYSTEM.value: + if provider_record.provider_type != ProviderType.SYSTEM: continue provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( @@ -627,8 +627,8 @@ class ProviderManager: tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. 
provider_name=ModelProviderID(provider_name).provider_name, - provider_type=ProviderType.SYSTEM.value, - quota_type=ProviderQuotaType.TRIAL.value, + provider_type=ProviderType.SYSTEM, + quota_type=ProviderQuotaType.TRIAL, quota_limit=quota.quota_limit, # type: ignore quota_used=0, is_valid=True, @@ -641,8 +641,8 @@ class ProviderManager: stmt = select(Provider).where( Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value, + Provider.provider_type == ProviderType.SYSTEM, + Provider.quota_type == ProviderQuotaType.TRIAL, ) existed_provider_record = db.session.scalar(stmt) if not existed_provider_record: @@ -702,7 +702,7 @@ class ProviderManager: """Get custom provider configuration.""" # Find custom provider record (non-system) custom_provider_record = next( - (record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None + (record for record in provider_records if record.provider_type != ProviderType.SYSTEM), None ) if not custom_provider_record: @@ -905,7 +905,7 @@ class ProviderManager: # Convert provider_records to dict quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} for provider_record in provider_records: - if provider_record.provider_type != ProviderType.SYSTEM.value: + if provider_record.provider_type != ProviderType.SYSTEM: continue quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( @@ -1046,7 +1046,7 @@ class ProviderManager: """ secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: - if credential_form_schema.type.value == FormType.SECRET_INPUT.value: + if credential_form_schema.type.value == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.variable) return secret_input_form_variables diff --git 
a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 696e3e967f..cc946a72c3 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -46,7 +46,7 @@ class DataPostProcessor: reranking_model: dict | None = None, weights: dict | None = None, ) -> BaseRerankRunner | None: - if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: + if reranking_mode == RerankMode.WEIGHTED_SCORE and weights: runner = RerankRunnerFactory.create_rerank_runner( runner_type=reranking_mode, tenant_id=tenant_id, @@ -62,7 +62,7 @@ class DataPostProcessor: ), ) return runner - elif reranking_mode == RerankMode.RERANKING_MODEL.value: + elif reranking_mode == RerankMode.RERANKING_MODEL: rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model) if rerank_model_instance is None: return None diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 63a1d911ca..6e9e2b4527 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -21,7 +21,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 4, @@ -107,7 +107,7 @@ class RetrievalService: raise ValueError(";\n".join(exceptions)) # Deduplicate documents for hybrid search to avoid duplicate chunks - if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: + if retrieval_method == RetrievalMethod.HYBRID_SEARCH: all_documents = cls._deduplicate_documents(all_documents) data_post_processor = DataPostProcessor( str(dataset.tenant_id), 
reranking_mode, reranking_model, weights, False @@ -134,7 +134,7 @@ class RetrievalService: if not dataset: return [] metadata_condition = ( - MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None + MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None ) all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( dataset.tenant_id, @@ -245,10 +245,10 @@ class RetrievalService: reranking_model and reranking_model.get("reranking_model_name") and reranking_model.get("reranking_provider_name") - and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH ): data_post_processor = DataPostProcessor( - str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False + str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) all_documents.extend( data_post_processor.invoke( @@ -293,10 +293,10 @@ class RetrievalService: reranking_model and reranking_model.get("reranking_model_name") and reranking_model.get("reranking_provider_name") - and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH ): data_post_processor = DataPostProcessor( - str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False + str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) all_documents.extend( data_post_processor.invoke( diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index e55e5f3101..a306f9ba0c 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -488,9 +488,9 @@ class ClickzettaVector(BaseVector): create_table_sql = f""" CREATE TABLE IF NOT EXISTS 
{self._config.schema_name}.{self._table_name} ( id STRING NOT NULL COMMENT 'Unique document identifier', - {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval', - {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes', - {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT + {Field.CONTENT_KEY} STRING NOT NULL COMMENT 'Document text content for search and retrieval', + {Field.METADATA_KEY} JSON COMMENT 'Document metadata including source, type, and other attributes', + {Field.VECTOR} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT 'High-dimensional embedding vector for semantic similarity search', PRIMARY KEY (id) ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' @@ -519,15 +519,15 @@ class ClickzettaVector(BaseVector): existing_indexes = cursor.fetchall() for idx in existing_indexes: # Check if vector index already exists on the embedding column - if Field.VECTOR.value in str(idx).lower(): - logger.info("Vector index already exists on column %s", Field.VECTOR.value) + if Field.VECTOR in str(idx).lower(): + logger.info("Vector index already exists on column %s", Field.VECTOR) return except (RuntimeError, ValueError) as e: logger.warning("Failed to check existing indexes: %s", e) index_sql = f""" CREATE VECTOR INDEX IF NOT EXISTS {index_name} - ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value}) + ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR}) PROPERTIES ( "distance.function" = "{self._config.vector_distance_function}", "scalar.type" = "f32", @@ -560,17 +560,17 @@ class ClickzettaVector(BaseVector): # More precise check: look for inverted index specifically on the content column if ( "inverted" in idx_str - and Field.CONTENT_KEY.value.lower() in idx_str + and Field.CONTENT_KEY.lower() in idx_str and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in 
idx_str) ): - logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx) + logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY, idx) return except (RuntimeError, ValueError) as e: logger.warning("Failed to check existing indexes: %s", e) index_sql = f""" CREATE INVERTED INDEX IF NOT EXISTS {index_name} - ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value}) + ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY}) PROPERTIES ( "analyzer" = "{self._config.analyzer_type}", "mode" = "{self._config.analyzer_mode}" @@ -588,13 +588,13 @@ class ClickzettaVector(BaseVector): or "with the same type" in error_msg or "cannot create inverted index" in error_msg ) and "already has index" in error_msg: - logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value) + logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY) # Try to get the existing index name for logging try: cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") existing_indexes = cursor.fetchall() for idx in existing_indexes: - if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower(): + if "inverted" in str(idx).lower() and Field.CONTENT_KEY.lower() in str(idx).lower(): logger.info("Found existing inverted index: %s", idx) break except (RuntimeError, ValueError): @@ -669,7 +669,7 @@ class ClickzettaVector(BaseVector): # Use parameterized INSERT with executemany for better performance and security # Cast JSON and VECTOR in SQL, pass raw data as parameters - columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}" + columns = f"id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}" insert_sql = ( f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) " f"VALUES (?, ?, CAST(? AS JSON), CAST(? 
AS VECTOR({vector_dimension})))" @@ -767,7 +767,7 @@ class ClickzettaVector(BaseVector): # Use json_extract_string function for ClickZetta compatibility sql = ( f"DELETE FROM {self._config.schema_name}.{self._table_name} " - f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?" + f"WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?" ) cursor.execute(sql, binding_params=[value]) @@ -795,9 +795,7 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append( - f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" - ) + filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})") # No need for dataset_id filter since each dataset has its own table @@ -808,23 +806,21 @@ class ClickzettaVector(BaseVector): distance_func = "COSINE_DISTANCE" if score_threshold > 0: query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" - filter_clauses.append( - f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}" - ) + filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {2 - score_threshold}") else: # For L2 distance, smaller is better distance_func = "L2_DISTANCE" if score_threshold > 0: query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" - filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}") + filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {score_threshold}") where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1" # Execute vector search query query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS 
VECTOR({vector_dimension}))" search_sql = f""" - SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, - {distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance + SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, + {distance_func}({Field.VECTOR}, {query_vector_str}) AS distance FROM {self._config.schema_name}.{self._table_name} WHERE {where_clause} ORDER BY distance @@ -887,9 +883,7 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append( - f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" - ) + filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})") # No need for dataset_id filter since each dataset has its own table @@ -897,13 +891,13 @@ class ClickzettaVector(BaseVector): # match_all requires all terms to be present # Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause escaped_query = query.replace("'", "''") - filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')") + filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY}, '{escaped_query}')") where_clause = " AND ".join(filter_clauses) # Execute full-text search query search_sql = f""" - SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY} FROM {self._config.schema_name}.{self._table_name} WHERE {where_clause} LIMIT {top_k} @@ -986,19 +980,17 @@ class ClickzettaVector(BaseVector): safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) # Use json_extract_string function for ClickZetta compatibility - filter_clauses.append( - f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN 
({doc_ids_str})" - ) + filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})") # No need for dataset_id filter since each dataset has its own table # Use simple quote escaping for LIKE clause escaped_query = query.replace("'", "''") - filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'") + filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'") where_clause = " AND ".join(filter_clauses) search_sql = f""" - SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY} FROM {self._config.schema_name}.{self._table_name} WHERE {where_clause} LIMIT {top_k} diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py index 7b00928b7b..1e7fe52666 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py @@ -57,18 +57,18 @@ class ElasticSearchJaVector(ElasticSearchVector): } mappings = { "properties": { - Field.CONTENT_KEY.value: { + Field.CONTENT_KEY: { "type": "text", "analyzer": "ja_analyzer", "search_analyzer": "ja_analyzer", }, - Field.VECTOR.value: { # Make sure the dimension is correct here + Field.VECTOR: { # Make sure the dimension is correct here "type": "dense_vector", "dims": dim, "index": True, "similarity": "cosine", }, - Field.METADATA_KEY.value: { + Field.METADATA_KEY: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 2c147fa7ca..0ff8c915e6 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -4,7 +4,7 @@ import math from typing 
import Any, cast from urllib.parse import urlparse -import requests +from elasticsearch import ConnectionError as ElasticsearchConnectionError from elasticsearch import Elasticsearch from flask import current_app from packaging.version import parse as parse_version @@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector): if not client.ping(): raise ConnectionError("Failed to connect to Elasticsearch") - except requests.ConnectionError as e: + except ElasticsearchConnectionError as e: raise ConnectionError(f"Vector database connection error: {str(e)}") except Exception as e: raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") @@ -163,9 +163,9 @@ class ElasticSearchVector(BaseVector): index=self._collection_name, id=uuids[i], document={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i] or None, - Field.METADATA_KEY.value: documents[i].metadata or {}, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i] or None, + Field.METADATA_KEY: documents[i].metadata or {}, }, ) self._client.indices.refresh(index=self._collection_name) @@ -193,7 +193,7 @@ class ElasticSearchVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) num_candidates = math.ceil(top_k * 1.5) - knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} + knn = {"field": Field.VECTOR, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} @@ -205,9 +205,9 @@ class ElasticSearchVector(BaseVector): docs_and_scores.append( ( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + 
page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ), hit["_score"], ) @@ -224,13 +224,13 @@ class ElasticSearchVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}} + query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}} document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: query_str = { "bool": { - "must": {"match": {Field.CONTENT_KEY.value: query}}, + "must": {"match": {Field.CONTENT_KEY: query}}, "filter": {"terms": {"metadata.document_id": document_ids_filter}}, } } @@ -240,9 +240,9 @@ class ElasticSearchVector(BaseVector): for hit in results["hits"]["hits"]: docs.append( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ) ) @@ -270,14 +270,14 @@ class ElasticSearchVector(BaseVector): dim = len(embeddings[0]) mappings = { "properties": { - Field.CONTENT_KEY.value: {"type": "text"}, - Field.VECTOR.value: { # Make sure the dimension is correct here + Field.CONTENT_KEY: {"type": "text"}, + Field.VECTOR: { # Make sure the dimension is correct here "type": "dense_vector", "dims": dim, "index": True, "similarity": "cosine", }, - Field.METADATA_KEY.value: { + Field.METADATA_KEY: { "type": "object", "properties": { "doc_id": {"type": "keyword"}, # Map doc_id to keyword type diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index cfee090768..c7b6593a8f 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -67,9 +67,9 
@@ class HuaweiCloudVector(BaseVector): index=self._collection_name, id=uuids[i], document={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i] or None, - Field.METADATA_KEY.value: documents[i].metadata or {}, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i] or None, + Field.METADATA_KEY: documents[i].metadata or {}, }, ) self._client.indices.refresh(index=self._collection_name) @@ -101,7 +101,7 @@ class HuaweiCloudVector(BaseVector): "size": top_k, "query": { "vector": { - Field.VECTOR.value: { + Field.VECTOR: { "vector": query_vector, "topk": top_k, } @@ -116,9 +116,9 @@ class HuaweiCloudVector(BaseVector): docs_and_scores.append( ( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ), hit["_score"], ) @@ -135,15 +135,15 @@ class HuaweiCloudVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = {"match": {Field.CONTENT_KEY.value: query}} + query_str = {"match": {Field.CONTENT_KEY: query}} results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) docs = [] for hit in results["hits"]["hits"]: docs.append( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ) ) @@ -171,8 +171,8 @@ class HuaweiCloudVector(BaseVector): dim = len(embeddings[0]) mappings = { "properties": { - Field.CONTENT_KEY.value: {"type": "text"}, - Field.VECTOR.value: { # Make sure the dimension is correct here + Field.CONTENT_KEY: 
{"type": "text"}, + Field.VECTOR: { # Make sure the dimension is correct here "type": "vector", "dimension": dim, "indexing": True, @@ -181,7 +181,7 @@ class HuaweiCloudVector(BaseVector): "neighbors": 32, "efc": 128, }, - Field.METADATA_KEY.value: { + Field.METADATA_KEY: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 8824e1c67b..bfcb620618 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -125,9 +125,9 @@ class LindormVectorStore(BaseVector): } } action_values: dict[str, Any] = { - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i], + Field.METADATA_KEY: documents[i].metadata, } if self._using_ugc: action_header["index"]["routing"] = self._routing @@ -149,7 +149,7 @@ class LindormVectorStore(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): query: dict[str, Any] = { - "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}} + "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}} } if self._using_ugc: query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}}) @@ -252,14 +252,14 @@ class LindormVectorStore(BaseVector): search_query: dict[str, Any] = { "size": top_k, "_source": True, - "query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}}, + "query": {"knn": {Field.VECTOR: {"vector": query_vector, "k": top_k}}}, } final_ext: dict[str, Any] = {"lvector": {}} if filters is not None and len(filters) > 0: # when using filter, transform filter from List[Dict] to Dict as valid format filter_dict = {"bool": 
{"must": filters}} if len(filters) > 1 else filters[0] - search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict # filter should be Dict + search_query["query"]["knn"][Field.VECTOR]["filter"] = filter_dict # filter should be Dict final_ext["lvector"]["filter_type"] = "pre_filter" if final_ext != {"lvector": {}}: @@ -279,9 +279,9 @@ class LindormVectorStore(BaseVector): docs_and_scores.append( ( Document( - page_content=hit["_source"][Field.CONTENT_KEY.value], - vector=hit["_source"][Field.VECTOR.value], - metadata=hit["_source"][Field.METADATA_KEY.value], + page_content=hit["_source"][Field.CONTENT_KEY], + vector=hit["_source"][Field.VECTOR], + metadata=hit["_source"][Field.METADATA_KEY], ), hit["_score"], ) @@ -318,9 +318,9 @@ class LindormVectorStore(BaseVector): docs = [] for hit in response["hits"]["hits"]: - metadata = hit["_source"].get(Field.METADATA_KEY.value) - vector = hit["_source"].get(Field.VECTOR.value) - page_content = hit["_source"].get(Field.CONTENT_KEY.value) + metadata = hit["_source"].get(Field.METADATA_KEY) + vector = hit["_source"].get(Field.VECTOR) + page_content = hit["_source"].get(Field.CONTENT_KEY) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) @@ -342,8 +342,8 @@ class LindormVectorStore(BaseVector): "settings": {"index": {"knn": True, "knn_routing": self._using_ugc}}, "mappings": { "properties": { - Field.CONTENT_KEY.value: {"type": "text"}, - Field.VECTOR.value: { + Field.CONTENT_KEY: {"type": "text"}, + Field.VECTOR: { "type": "knn_vector", "dimension": len(embeddings[0]), # Make sure the dimension is correct here "method": { diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 5f32feb709..96eb465401 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -85,7 +85,7 @@ class MilvusVector(BaseVector): collection_info = 
self._client.describe_collection(self._collection_name) fields = [field["name"] for field in collection_info["fields"]] # Since primary field is auto-id, no need to track it - self._fields = [f for f in fields if f != Field.PRIMARY_KEY.value] + self._fields = [f for f in fields if f != Field.PRIMARY_KEY] def _check_hybrid_search_support(self) -> bool: """ @@ -130,9 +130,9 @@ class MilvusVector(BaseVector): insert_dict = { # Do not need to insert the sparse_vector field separately, as the text_bm25_emb # function will automatically convert the native text into a sparse vector for us. - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i], + Field.METADATA_KEY: documents[i].metadata, } insert_dict_list.append(insert_dict) # Total insert count @@ -243,15 +243,15 @@ class MilvusVector(BaseVector): results = self._client.search( collection_name=self._collection_name, data=[query_vector], - anns_field=Field.VECTOR.value, + anns_field=Field.VECTOR, limit=kwargs.get("top_k", 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY], filter=filter, ) return self._process_search_results( results, - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY], score_threshold=float(kwargs.get("score_threshold") or 0.0), ) @@ -264,7 +264,7 @@ class MilvusVector(BaseVector): "Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)." ) return [] - if not self.field_exists(Field.SPARSE_VECTOR.value): + if not self.field_exists(Field.SPARSE_VECTOR): logger.warning( "Full-text search unavailable: collection missing 'sparse_vector' field; " "recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index." 
@@ -279,15 +279,15 @@ class MilvusVector(BaseVector): results = self._client.search( collection_name=self._collection_name, data=[query], - anns_field=Field.SPARSE_VECTOR.value, + anns_field=Field.SPARSE_VECTOR, limit=kwargs.get("top_k", 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY], filter=filter, ) return self._process_search_results( results, - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY], score_threshold=float(kwargs.get("score_threshold") or 0.0), ) @@ -311,7 +311,7 @@ class MilvusVector(BaseVector): dim = len(embeddings[0]) fields = [] if metadatas: - fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) + fields.append(FieldSchema(Field.METADATA_KEY, DataType.JSON, max_length=65_535)) # Create the text field, enable_analyzer will be set True to support milvus automatically # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md @@ -326,15 +326,15 @@ class MilvusVector(BaseVector): ): content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params - fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs)) + fields.append(FieldSchema(Field.CONTENT_KEY, DataType.VARCHAR, **content_field_kwargs)) # Create the primary key field - fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) + fields.append(FieldSchema(Field.PRIMARY_KEY, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors - fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) + fields.append(FieldSchema(Field.VECTOR, infer_dtype_bydata(embeddings[0]), dim=dim)) # Create Sparse Vector Index for the collection if self._hybrid_search_enabled: - fields.append(FieldSchema(Field.SPARSE_VECTOR.value, 
DataType.SPARSE_FLOAT_VECTOR)) + fields.append(FieldSchema(Field.SPARSE_VECTOR, DataType.SPARSE_FLOAT_VECTOR)) schema = CollectionSchema(fields) @@ -342,8 +342,8 @@ class MilvusVector(BaseVector): if self._hybrid_search_enabled: bm25_function = Function( name="text_bm25_emb", - input_field_names=[Field.CONTENT_KEY.value], - output_field_names=[Field.SPARSE_VECTOR.value], + input_field_names=[Field.CONTENT_KEY], + output_field_names=[Field.SPARSE_VECTOR], function_type=FunctionType.BM25, ) schema.add_function(bm25_function) @@ -352,12 +352,12 @@ class MilvusVector(BaseVector): # Create Index params for the collection index_params_obj = IndexParams() - index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) + index_params_obj.add_index(field_name=Field.VECTOR, **index_params) # Create Sparse Vector Index for the collection if self._hybrid_search_enabled: index_params_obj.add_index( - field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25" + field_name=Field.SPARSE_VECTOR, index_type="AUTOINDEX", metric_type="BM25" ) # Create the collection diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 3eb1df027e..80ffdadd96 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Literal +from typing import Any from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers @@ -8,6 +8,7 @@ from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from configs import dify_config +from configs.middleware.vdb.opensearch_config import AuthMethod from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import 
AbstractVectorFactory @@ -25,7 +26,7 @@ class OpenSearchConfig(BaseModel): port: int secure: bool = False # use_ssl verify_certs: bool = True - auth_method: Literal["basic", "aws_managed_iam"] = "basic" + auth_method: AuthMethod = AuthMethod.BASIC user: str | None = None password: str | None = None aws_region: str | None = None @@ -98,9 +99,9 @@ class OpenSearchVector(BaseVector): "_op_type": "index", "_index": self._collection_name.lower(), "_source": { - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], # Make sure you pass an array here - Field.METADATA_KEY.value: documents[i].metadata, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i], # Make sure you pass an array here + Field.METADATA_KEY: documents[i].metadata, }, } # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377 @@ -116,7 +117,7 @@ class OpenSearchVector(BaseVector): ) def get_ids_by_metadata_field(self, key: str, value: str): - query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} + query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}} response = self._client.search(index=self._collection_name.lower(), body=query) if response["hits"]["hits"]: return [hit["_id"] for hit in response["hits"]["hits"]] @@ -180,17 +181,17 @@ class OpenSearchVector(BaseVector): query = { "size": kwargs.get("top_k", 4), - "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, + "query": {"knn": {Field.VECTOR: {Field.VECTOR: query_vector, "k": kwargs.get("top_k", 4)}}}, } document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: query["query"] = { "script_score": { - "query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}}, + "query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID: document_ids_filter}}]}}, "script": { "source": "knn_score", "lang": "knn", - "params": {"field": 
Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"}, + "params": {"field": Field.VECTOR, "query_value": query_vector, "space_type": "l2"}, }, } } @@ -203,7 +204,7 @@ class OpenSearchVector(BaseVector): docs = [] for hit in response["hits"]["hits"]: - metadata = hit["_source"].get(Field.METADATA_KEY.value, {}) + metadata = hit["_source"].get(Field.METADATA_KEY, {}) # Make sure metadata is a dictionary if metadata is None: @@ -212,7 +213,7 @@ class OpenSearchVector(BaseVector): metadata["score"] = hit["_score"] score_threshold = float(kwargs.get("score_threshold") or 0.0) if hit["_score"] >= score_threshold: - doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) + doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY), metadata=metadata) docs.append(doc) return docs @@ -227,9 +228,9 @@ class OpenSearchVector(BaseVector): docs = [] for hit in response["hits"]["hits"]: - metadata = hit["_source"].get(Field.METADATA_KEY.value) - vector = hit["_source"].get(Field.VECTOR.value) - page_content = hit["_source"].get(Field.CONTENT_KEY.value) + metadata = hit["_source"].get(Field.METADATA_KEY) + vector = hit["_source"].get(Field.VECTOR) + page_content = hit["_source"].get(Field.CONTENT_KEY) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) @@ -250,8 +251,8 @@ class OpenSearchVector(BaseVector): "settings": {"index": {"knn": True}}, "mappings": { "properties": { - Field.CONTENT_KEY.value: {"type": "text"}, - Field.VECTOR.value: { + Field.CONTENT_KEY: {"type": "text"}, + Field.VECTOR: { "type": "knn_vector", "dimension": len(embeddings[0]), # Make sure the dimension is correct here "method": { @@ -261,7 +262,7 @@ class OpenSearchVector(BaseVector): "parameters": {"ef_construction": 64, "m": 8}, }, }, - Field.METADATA_KEY.value: { + Field.METADATA_KEY: { "type": "object", "properties": { "doc_id": {"type": "keyword"}, # Map doc_id to keyword type @@ -293,7 +294,7 @@ class 
OpenSearchVectorFactory(AbstractVectorFactory): port=dify_config.OPENSEARCH_PORT, secure=dify_config.OPENSEARCH_SECURE, verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS, - auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value, + auth_method=dify_config.OPENSEARCH_AUTH_METHOD, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, aws_region=dify_config.OPENSEARCH_AWS_REGION, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index d46f29bd64..f8c62b908a 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -147,15 +147,13 @@ class QdrantVector(BaseVector): # create group_id payload index self._client.create_payload_index( - collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD ) # create doc_id payload index - self._client.create_payload_index( - collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD - ) + self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD) # create document_id payload index self._client.create_payload_index( - collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD + collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD ) # create full text index text_index_params = TextIndexParams( @@ -165,9 +163,7 @@ class QdrantVector(BaseVector): max_token_len=20, lowercase=True, ) - self._client.create_payload_index( - collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params - ) + self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -220,10 +216,10 @@ class 
QdrantVector(BaseVector): self._build_payloads( batch_texts, batch_metadatas, - Field.CONTENT_KEY.value, - Field.METADATA_KEY.value, + Field.CONTENT_KEY, + Field.METADATA_KEY, group_id or "", # Ensure group_id is never None - Field.GROUP_KEY.value, + Field.GROUP_KEY, ), ) ] @@ -381,12 +377,12 @@ class QdrantVector(BaseVector): for result in results: if result.payload is None: continue - metadata = result.payload.get(Field.METADATA_KEY.value) or {} + metadata = result.payload.get(Field.METADATA_KEY) or {} # duplicate check score threshold if result.score >= score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value, ""), + page_content=result.payload.get(Field.CONTENT_KEY, ""), metadata=metadata, ) docs.append(doc) @@ -433,7 +429,7 @@ class QdrantVector(BaseVector): documents = [] for result in results: if result: - document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) + document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) documents.append(document) return documents diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index e91d9bb0d6..f2156afa59 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -55,7 +55,7 @@ class TableStoreVector(BaseVector): self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score self._table_name = f"{collection_name}" self._index_name = f"{collection_name}_idx" - self._tags_field = f"{Field.METADATA_KEY.value}_tags" + self._tags_field = f"{Field.METADATA_KEY}_tags" def create_collection(self, embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) @@ -64,7 +64,7 @@ class TableStoreVector(BaseVector): def get_by_ids(self, ids: list[str]) -> list[Document]: docs = [] request = 
BatchGetRowRequest() - columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value] + columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY] rows_to_get = [[("id", _id)] for _id in ids] request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1)) @@ -73,11 +73,7 @@ class TableStoreVector(BaseVector): for item in table_result: if item.is_ok and item.row: kv = {k: v for k, v, _ in item.row.attribute_columns} - docs.append( - Document( - page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value]) - ) - ) + docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY]))) return docs def get_type(self) -> str: @@ -95,9 +91,9 @@ class TableStoreVector(BaseVector): self._write_row( primary_key=uuids[i], attributes={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata, + Field.CONTENT_KEY: documents[i].page_content, + Field.VECTOR: embeddings[i], + Field.METADATA_KEY: documents[i].metadata, }, ) return uuids @@ -180,7 +176,7 @@ class TableStoreVector(BaseVector): field_schemas = [ tablestore.FieldSchema( - Field.CONTENT_KEY.value, + Field.CONTENT_KEY, tablestore.FieldType.TEXT, analyzer=tablestore.AnalyzerType.MAXWORD, index=True, @@ -188,7 +184,7 @@ class TableStoreVector(BaseVector): store=False, ), tablestore.FieldSchema( - Field.VECTOR.value, + Field.VECTOR, tablestore.FieldType.VECTOR, vector_options=tablestore.VectorOptions( data_type=tablestore.VectorDataType.VD_FLOAT_32, @@ -197,7 +193,7 @@ class TableStoreVector(BaseVector): ), ), tablestore.FieldSchema( - Field.METADATA_KEY.value, + Field.METADATA_KEY, tablestore.FieldType.KEYWORD, index=True, store=False, @@ -233,15 +229,15 @@ class TableStoreVector(BaseVector): pk = [("id", primary_key)] tags = [] - for key, value in attributes[Field.METADATA_KEY.value].items(): + for key, value in 
attributes[Field.METADATA_KEY].items(): tags.append(str(key) + "=" + str(value)) attribute_columns = [ - (Field.CONTENT_KEY.value, attributes[Field.CONTENT_KEY.value]), - (Field.VECTOR.value, json.dumps(attributes[Field.VECTOR.value])), + (Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]), + (Field.VECTOR, json.dumps(attributes[Field.VECTOR])), ( - Field.METADATA_KEY.value, - json.dumps(attributes[Field.METADATA_KEY.value]), + Field.METADATA_KEY, + json.dumps(attributes[Field.METADATA_KEY]), ), (self._tags_field, json.dumps(tags)), ] @@ -270,7 +266,7 @@ class TableStoreVector(BaseVector): index_name=self._index_name, search_query=query, columns_to_get=tablestore.ColumnsToGet( - column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED + column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED ), ) @@ -288,7 +284,7 @@ class TableStoreVector(BaseVector): self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float ) -> list[Document]: knn_vector_query = tablestore.KnnVectorQuery( - field_name=Field.VECTOR.value, + field_name=Field.VECTOR, top_k=top_k, float32_query_vector=query_vector, ) @@ -311,8 +307,8 @@ class TableStoreVector(BaseVector): for col in search_hit.row[1]: ots_column_map[col[0]] = col[1] - vector_str = ots_column_map.get(Field.VECTOR.value) - metadata_str = ots_column_map.get(Field.METADATA_KEY.value) + vector_str = ots_column_map.get(Field.VECTOR) + metadata_str = ots_column_map.get(Field.METADATA_KEY) vector = json.loads(vector_str) if vector_str else None metadata = json.loads(metadata_str) if metadata_str else {} @@ -321,7 +317,7 @@ class TableStoreVector(BaseVector): documents.append( Document( - page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", + page_content=ots_column_map.get(Field.CONTENT_KEY) or "", vector=vector, metadata=metadata, ) @@ -343,7 +339,7 @@ class TableStoreVector(BaseVector): self, query: str, 
document_ids_filter: list[str] | None, top_k: int, score_threshold: float ) -> list[Document]: bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[]) - bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) + bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY)) if document_ids_filter: bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter)) @@ -374,10 +370,10 @@ class TableStoreVector(BaseVector): for col in search_hit.row[1]: ots_column_map[col[0]] = col[1] - metadata_str = ots_column_map.get(Field.METADATA_KEY.value) + metadata_str = ots_column_map.get(Field.METADATA_KEY) metadata = json.loads(metadata_str) if metadata_str else {} - vector_str = ots_column_map.get(Field.VECTOR.value) + vector_str = ots_column_map.get(Field.VECTOR) vector = json.loads(vector_str) if vector_str else None if score: @@ -385,7 +381,7 @@ class TableStoreVector(BaseVector): documents.append( Document( - page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", + page_content=ots_column_map.get(Field.CONTENT_KEY) or "", vector=vector, metadata=metadata, ) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index f90a311df4..56ffb36a2b 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -5,9 +5,10 @@ from collections.abc import Generator, Iterable, Sequence from itertools import islice from typing import TYPE_CHECKING, Any, Union +import httpx import qdrant_client -import requests from flask import current_app +from httpx import DigestAuth from pydantic import BaseModel from qdrant_client.http import models as rest from qdrant_client.http.models import ( @@ -19,7 +20,6 @@ from qdrant_client.http.models 
import ( TokenizerType, ) from qdrant_client.local.qdrant_local import QdrantLocal -from requests.auth import HTTPDigestAuth from sqlalchemy import select from configs import dify_config @@ -141,15 +141,13 @@ class TidbOnQdrantVector(BaseVector): # create group_id payload index self._client.create_payload_index( - collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD ) # create doc_id payload index - self._client.create_payload_index( - collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD - ) + self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD) # create document_id payload index self._client.create_payload_index( - collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD + collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD ) # create full text index text_index_params = TextIndexParams( @@ -159,9 +157,7 @@ class TidbOnQdrantVector(BaseVector): max_token_len=20, lowercase=True, ) - self._client.create_payload_index( - collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params - ) + self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -211,10 +207,10 @@ class TidbOnQdrantVector(BaseVector): self._build_payloads( batch_texts, batch_metadatas, - Field.CONTENT_KEY.value, - Field.METADATA_KEY.value, + Field.CONTENT_KEY, + Field.METADATA_KEY, group_id or "", - Field.GROUP_KEY.value, + Field.GROUP_KEY, ), ) ] @@ -349,13 +345,13 @@ class TidbOnQdrantVector(BaseVector): for result in results: if result.payload is None: continue - metadata = result.payload.get(Field.METADATA_KEY.value) or {} + metadata = result.payload.get(Field.METADATA_KEY) or {} 
# duplicate check score threshold score_threshold = kwargs.get("score_threshold") or 0.0 if result.score >= score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value, ""), + page_content=result.payload.get(Field.CONTENT_KEY, ""), metadata=metadata, ) docs.append(doc) @@ -392,7 +388,7 @@ class TidbOnQdrantVector(BaseVector): documents = [] for result in results: if result: - document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) + document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) documents.append(document) return documents @@ -504,10 +500,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): } cluster_data = {"displayName": display_name, "region": region_object, "labels": labels} - response = requests.post( + response = httpx.post( f"{tidb_config.api_url}/clusters", json=cluster_data, - auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + auth=DigestAuth(tidb_config.public_key, tidb_config.private_key), ) if response.status_code == 200: @@ -527,10 +523,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): body = {"password": new_password} - response = requests.put( + response = httpx.put( f"{tidb_config.api_url}/clusters/{cluster_id}/password", json=body, - auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + auth=DigestAuth(tidb_config.public_key, tidb_config.private_key), ) if response.status_code == 200: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index e1d4422144..754c149241 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -2,8 +2,8 @@ import time import uuid from collections.abc import Sequence -import requests -from requests.auth import HTTPDigestAuth +import httpx +from httpx 
import DigestAuth from configs import dify_config from extensions.ext_database import db @@ -49,7 +49,7 @@ class TidbService: "rootPassword": password, } - response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key)) + response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)) if response.status_code == 200: response_data = response.json() @@ -83,7 +83,7 @@ class TidbService: :return: The response from the API. """ - response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) if response.status_code == 200: return response.json() @@ -102,7 +102,7 @@ class TidbService: :return: The response from the API. """ - response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)) if response.status_code == 200: return response.json() @@ -127,10 +127,10 @@ class TidbService: body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} - response = requests.patch( + response = httpx.patch( f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", json=body, - auth=HTTPDigestAuth(public_key, private_key), + auth=DigestAuth(public_key, private_key), ) if response.status_code == 200: @@ -161,9 +161,7 @@ class TidbService: tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} cluster_ids = [item.cluster_id for item in tidb_serverless_list] params = {"clusterIds": cluster_ids, "view": "BASIC"} - response = requests.get( - f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key) - ) + response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key)) if 
response.status_code == 200: response_data = response.json() @@ -224,8 +222,8 @@ class TidbService: clusters.append(cluster_data) request_body = {"requests": clusters} - response = requests.post( - f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key) + response = httpx.post( + f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key) ) if response.status_code == 200: diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index b8897c4165..27ae038a06 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -55,13 +55,13 @@ class TiDBVector(BaseVector): return Table( self._collection_name, self._orm_base.metadata, - Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False), + Column(Field.PRIMARY_KEY, String(36), primary_key=True, nullable=False), Column( - Field.VECTOR.value, + Field.VECTOR, VectorType(dim), nullable=False, ), - Column(Field.TEXT_KEY.value, TEXT, nullable=False), + Column(Field.TEXT_KEY, TEXT, nullable=False), Column("meta", JSON, nullable=False), Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), Column( diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index d1bdd3baef..e5feecf2bc 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -76,11 +76,11 @@ class VikingDBVector(BaseVector): if not self._has_collection(): fields = [ - Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True), - Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String), - Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String), - 
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text), - Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension), + Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True), + Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String), + Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String), + Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text), + Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=dimension), ] self._client.create_collection( @@ -100,7 +100,7 @@ class VikingDBVector(BaseVector): collection_name=self._collection_name, index_name=self._index_name, vector_index=vector_index, - partition_by=vdb_Field.GROUP_KEY.value, + partition_by=vdb_Field.GROUP_KEY, description="Index For Dify", ) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -126,11 +126,11 @@ class VikingDBVector(BaseVector): # FIXME: fix the type of metadata later doc = Data( { - vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore - vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, - vdb_Field.CONTENT_KEY.value: page_content, - vdb_Field.METADATA_KEY.value: json.dumps(metadata), - vdb_Field.GROUP_KEY.value: self._group_id, + vdb_Field.PRIMARY_KEY: metadatas[i]["doc_id"], # type: ignore + vdb_Field.VECTOR: embeddings[i] if embeddings else None, + vdb_Field.CONTENT_KEY: page_content, + vdb_Field.METADATA_KEY: json.dumps(metadata), + vdb_Field.GROUP_KEY: self._group_id, } ) docs.append(doc) @@ -151,7 +151,7 @@ class VikingDBVector(BaseVector): # Note: Metadata field value is an dict, but vikingdb field # not support json type results = self._client.get_index(self._collection_name, self._index_name).search( - filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]}, + filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]}, # max value is 5000 limit=5000, ) @@ -161,7 +161,7 
@@ class VikingDBVector(BaseVector): ids = [] for result in results: - metadata = result.fields.get(vdb_Field.METADATA_KEY.value) + metadata = result.fields.get(vdb_Field.METADATA_KEY) if metadata is not None: metadata = json.loads(metadata) if metadata.get(key) == value: @@ -189,12 +189,12 @@ class VikingDBVector(BaseVector): docs = [] for result in results: - metadata = result.fields.get(vdb_Field.METADATA_KEY.value) + metadata = result.fields.get(vdb_Field.METADATA_KEY) if metadata is not None: metadata = json.loads(metadata) if result.score >= score_threshold: metadata["score"] = result.score - doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) + doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata) docs.append(doc) docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 3ec08b93ed..8820c0a846 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -2,7 +2,6 @@ import datetime import json from typing import Any -import requests import weaviate # type: ignore from pydantic import BaseModel, model_validator @@ -45,8 +44,8 @@ class WeaviateVector(BaseVector): client = weaviate.Client( url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None ) - except requests.ConnectionError: - raise ConnectionError("Vector database connection error") + except Exception as exc: + raise ConnectionError("Vector database connection error") from exc client.batch.configure( # `batch_size` takes an `int` value to enable auto-batching @@ -105,7 +104,7 @@ class WeaviateVector(BaseVector): with self._client.batch as batch: for i, text in enumerate(texts): - data_properties = {Field.TEXT_KEY.value: text} + 
data_properties = {Field.TEXT_KEY: text} if metadatas is not None: # metadata maybe None for key, val in (metadatas[i] or {}).items(): @@ -183,7 +182,7 @@ class WeaviateVector(BaseVector): """Look up similar documents by embedding vector in Weaviate.""" collection_name = self._collection_name properties = self._attributes - properties.append(Field.TEXT_KEY.value) + properties.append(Field.TEXT_KEY) query_obj = self._client.query.get(collection_name, properties) vector = {"vector": query_vector} @@ -205,7 +204,7 @@ class WeaviateVector(BaseVector): docs_and_scores = [] for res in result["data"]["Get"][collection_name]: - text = res.pop(Field.TEXT_KEY.value) + text = res.pop(Field.TEXT_KEY) score = 1 - res["_additional"]["distance"] docs_and_scores.append((Document(page_content=text, metadata=res), score)) @@ -233,7 +232,7 @@ class WeaviateVector(BaseVector): collection_name = self._collection_name content: dict[str, Any] = {"concepts": [query]} properties = self._attributes - properties.append(Field.TEXT_KEY.value) + properties.append(Field.TEXT_KEY) if kwargs.get("search_distance"): content["certainty"] = kwargs.get("search_distance") query_obj = self._client.query.get(collection_name, properties) @@ -251,7 +250,7 @@ class WeaviateVector(BaseVector): raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][collection_name]: - text = res.pop(Field.TEXT_KEY.value) + text = res.pop(Field.TEXT_KEY) additional = res.pop("_additional") docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) return docs diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index 24db5d77be..a61b17ddb8 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -20,12 +20,12 @@ class BaseDatasourceEvent(BaseModel): class DatasourceErrorEvent(BaseDatasourceEvent): - event: str = DatasourceStreamEvent.ERROR.value + event: DatasourceStreamEvent = DatasourceStreamEvent.ERROR 
error: str = Field(..., description="error message") class DatasourceCompletedEvent(BaseDatasourceEvent): - event: str = DatasourceStreamEvent.COMPLETED.value + event: DatasourceStreamEvent = DatasourceStreamEvent.COMPLETED data: Mapping[str, Any] | list = Field(..., description="result") total: int | None = Field(default=0, description="total") completed: int | None = Field(default=0, description="completed") @@ -33,6 +33,6 @@ class DatasourceCompletedEvent(BaseDatasourceEvent): class DatasourceProcessingEvent(BaseDatasourceEvent): - event: str = DatasourceStreamEvent.PROCESSING.value + event: DatasourceStreamEvent = DatasourceStreamEvent.PROCESSING total: int | None = Field(..., description="total") completed: int | None = Field(..., description="completed") diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index b9bf9d0d8c..c3bfbce98f 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -17,9 +17,6 @@ class NotionInfo(BaseModel): tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) - def __init__(self, **data): - super().__init__(**data) - class WebsiteInfo(BaseModel): """ @@ -47,6 +44,3 @@ class ExtractSetting(BaseModel): website_info: WebsiteInfo | None = None document_model: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) - - def __init__(self, **data): - super().__init__(**data) diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 3dc08e1832..0f62f9c4b6 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -45,7 +45,7 @@ class ExtractProcessor: cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE.value, upload_file=upload_file, 
document_model="text_model" + datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model" ) if return_text: delimiter = "\n" @@ -76,7 +76,7 @@ class ExtractProcessor: # https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521 file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}" Path(file_path).write_bytes(response.content) - extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model") + extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model") if return_text: delimiter = "\n" return delimiter.join( @@ -92,7 +92,7 @@ class ExtractProcessor: def extract( cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: - if extract_setting.datasource_type == DatasourceType.FILE.value: + if extract_setting.datasource_type == DatasourceType.FILE: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: assert extract_setting.upload_file is not None, "upload_file is required" @@ -163,7 +163,7 @@ class ExtractProcessor: # txt extractor = TextExtractor(file_path, autodetect_encoding=True) return extractor.extract() - elif extract_setting.datasource_type == DatasourceType.NOTION.value: + elif extract_setting.datasource_type == DatasourceType.NOTION: assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( notion_workspace_id=extract_setting.notion_info.notion_workspace_id, @@ -174,7 +174,7 @@ class ExtractProcessor: credential_id=extract_setting.notion_info.credential_id, ) return extractor.extract() - elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: + elif extract_setting.datasource_type == DatasourceType.WEBSITE: assert extract_setting.website_info is not None, "website_info is required" if extract_setting.website_info.provider == "firecrawl": 
extractor = FirecrawlWebExtractor( diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index e1ba6ef243..c20ecd2b89 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -2,7 +2,7 @@ import json import time from typing import Any, cast -import requests +import httpx from extensions.ext_storage import storage @@ -104,18 +104,18 @@ class FirecrawlApp: def _prepare_headers(self) -> dict[str, Any]: return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} - def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response: + def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response: for attempt in range(retries): - response = requests.post(url, headers=headers, json=data) + response = httpx.post(url, headers=headers, json=data) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: return response return response - def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> requests.Response: + def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response: for attempt in range(retries): - response = requests.get(url, headers=headers) + response = httpx.get(url, headers=headers) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index bddf41af43..e87ab38349 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -3,7 +3,7 @@ import logging import operator from typing import Any, cast -import requests +import httpx from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor @@ -92,7 +92,7 @@ class NotionExtractor(BaseExtractor): if next_cursor: current_query["start_cursor"] = 
next_cursor - res = requests.post( + res = httpx.post( DATABASE_URL_TMPL.format(database_id=database_id), headers={ "Authorization": "Bearer " + self._notion_access_token, @@ -160,7 +160,7 @@ class NotionExtractor(BaseExtractor): while True: query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} try: - res = requests.request( + res = httpx.request( "GET", block_url, headers={ @@ -173,7 +173,7 @@ class NotionExtractor(BaseExtractor): if res.status_code != 200: raise ValueError(f"Error fetching Notion block data: {res.text}") data = res.json() - except requests.RequestException as e: + except httpx.HTTPError as e: raise ValueError("Error fetching Notion block data") from e if "results" not in data or not isinstance(data["results"], list): raise ValueError("Error fetching Notion block data") @@ -222,7 +222,7 @@ class NotionExtractor(BaseExtractor): while True: query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} - res = requests.request( + res = httpx.request( "GET", block_url, headers={ @@ -282,7 +282,7 @@ class NotionExtractor(BaseExtractor): while not done: query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} - res = requests.request( + res = httpx.request( "GET", block_url, headers={ @@ -354,7 +354,7 @@ class NotionExtractor(BaseExtractor): query_dict: dict[str, Any] = {} - res = requests.request( + res = httpx.request( "GET", retrieve_page_url, headers={ diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 6d596e07d8..7cf6c4d289 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -3,8 +3,8 @@ from collections.abc import Generator from typing import Union from urllib.parse import urljoin -import requests -from requests import Response +import httpx +from httpx import Response from core.rag.extractor.watercrawl.exceptions import ( 
WaterCrawlAuthenticationError, @@ -20,28 +20,45 @@ class BaseAPIClient: self.session = self.init_session() def init_session(self): - session = requests.Session() - session.headers.update({"X-API-Key": self.api_key}) - session.headers.update({"Content-Type": "application/json"}) - session.headers.update({"Accept": "application/json"}) - session.headers.update({"User-Agent": "WaterCrawl-Plugin"}) - session.headers.update({"Accept-Language": "en-US"}) - return session + headers = { + "X-API-Key": self.api_key, + "Content-Type": "application/json", + "Accept": "application/json", + "User-Agent": "WaterCrawl-Plugin", + "Accept-Language": "en-US", + } + return httpx.Client(headers=headers, timeout=None) + + def _request( + self, + method: str, + endpoint: str, + query_params: dict | None = None, + data: dict | None = None, + **kwargs, + ) -> Response: + stream = kwargs.pop("stream", False) + url = urljoin(self.base_url, endpoint) + if stream: + request = self.session.build_request(method, url, params=query_params, json=data) + return self.session.send(request, stream=True, **kwargs) + + return self.session.request(method, url, params=query_params, json=data, **kwargs) def _get(self, endpoint: str, query_params: dict | None = None, **kwargs): - return self.session.get(urljoin(self.base_url, endpoint), params=query_params, **kwargs) + return self._request("GET", endpoint, query_params=query_params, **kwargs) def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): - return self.session.post(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs) + return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs) def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): - return self.session.put(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs) + return self._request("PUT", endpoint, query_params=query_params, data=data, 
**kwargs) def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs): - return self.session.delete(urljoin(self.base_url, endpoint), params=query_params, **kwargs) + return self._request("DELETE", endpoint, query_params=query_params, **kwargs) def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): - return self.session.patch(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs) + return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs) class WaterCrawlAPIClient(BaseAPIClient): @@ -49,14 +66,17 @@ class WaterCrawlAPIClient(BaseAPIClient): super().__init__(api_key, base_url) def process_eventstream(self, response: Response, download: bool = False) -> Generator: - for line in response.iter_lines(): - line = line.decode("utf-8") - if line.startswith("data:"): - line = line[5:].strip() - data = json.loads(line) - if data["type"] == "result" and download: - data["data"] = self.download_result(data["data"]) - yield data + try: + for raw_line in response.iter_lines(): + line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line + if line.startswith("data:"): + line = line[5:].strip() + data = json.loads(line) + if data["type"] == "result" and download: + data["data"] = self.download_result(data["data"]) + yield data + finally: + response.close() def process_response(self, response: Response) -> dict | bytes | list | None | Generator: if response.status_code == 401: @@ -170,7 +190,10 @@ class WaterCrawlAPIClient(BaseAPIClient): return event_data["data"] def download_result(self, result_object: dict): - response = requests.get(result_object["result"]) - response.raise_for_status() - result_object["result"] = response.json() + response = httpx.get(result_object["result"], timeout=None) + try: + response.raise_for_status() + result_object["result"] = response.json() + finally: + response.close() return result_object diff --git 
a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index f25f92cf81..1a9704688a 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -9,7 +9,7 @@ import uuid from urllib.parse import urlparse from xml.etree import ElementTree -import requests +import httpx from docx import Document as DocxDocument from configs import dify_config @@ -43,15 +43,19 @@ class WordExtractor(BaseExtractor): # If the file is a web path, download it to a temporary file, and use that if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): - r = requests.get(self.file_path) + response = httpx.get(self.file_path, timeout=None) - if r.status_code != 200: - raise ValueError(f"Check the url of your file; returned status code {r.status_code}") + if response.status_code != 200: + response.close() + raise ValueError(f"Check the url of your file; returned status code {response.status_code}") self.web_path = self.file_path # TODO: use a better way to handle the file self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115 - self.temp_file.write(r.content) + try: + self.temp_file.write(response.content) + finally: + response.close() self.file_path = self.temp_file.name elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 755aa88d08..4fcffbcc77 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -38,11 +38,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if process_rule.get("mode") == "automatic": automatic_rule = DatasetProcessRule.AUTOMATIC_RULES - rules = Rule(**automatic_rule) + rules = 
Rule.model_validate(automatic_rule) else: if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) # Split the text documents into nodes. if not rules.segmentation: raise ValueError("No segmentation found in rules.") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index e0ccd8b567..7bdde286f5 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -40,7 +40,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) all_documents: list[Document] = [] if rules.parent_mode == ParentMode.PARAGRAPH: # Split the text documents into nodes. 
@@ -110,7 +110,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_documents = document.children if child_documents: formatted_child_documents = [ - Document(**child_document.model_dump()) for child_document in child_documents + Document.model_validate(child_document.model_dump()) for child_document in child_documents ] vector.create(formatted_child_documents) @@ -224,7 +224,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return child_nodes def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): - parent_childs = ParentChildStructureChunk(**chunks) + parent_childs = ParentChildStructureChunk.model_validate(chunks) documents = [] for parent_child in parent_childs.parent_child_chunks: metadata = { @@ -274,7 +274,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): vector.create(all_child_documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: - parent_childs = ParentChildStructureChunk(**chunks) + parent_childs = ParentChildStructureChunk.model_validate(chunks) preview = [] for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 2054031643..9c8f70dba8 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -47,7 +47,7 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) splitter = self._get_splitter( processing_rule_mode=process_rule.get("mode"), max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0, @@ -168,7 +168,7 @@ class 
QAIndexProcessor(BaseIndexProcessor): return docs def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): - qa_chunks = QAStructureChunk(**chunks) + qa_chunks = QAStructureChunk.model_validate(chunks) documents = [] for qa_chunk in qa_chunks.qa_chunks: metadata = { @@ -191,7 +191,7 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError("Indexing technique must be high quality.") def format_preview(self, chunks: Any) -> Mapping[str, Any]: - qa_chunks = QAStructureChunk(**chunks) + qa_chunks = QAStructureChunk.model_validate(chunks) preview = [] for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) diff --git a/api/core/rag/rerank/rerank_factory.py b/api/core/rag/rerank/rerank_factory.py index 1a3cf85736..524e83824c 100644 --- a/api/core/rag/rerank/rerank_factory.py +++ b/api/core/rag/rerank/rerank_factory.py @@ -8,9 +8,9 @@ class RerankRunnerFactory: @staticmethod def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner: match runner_type: - case RerankMode.RERANKING_MODEL.value: + case RerankMode.RERANKING_MODEL: return RerankModelRunner(*args, **kwargs) - case RerankMode.WEIGHTED_SCORE.value: + case RerankMode.WEIGHTED_SCORE: return WeightRerankRunner(*args, **kwargs) case _: raise ValueError(f"Unknown runner type: {runner_type}") diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index b08f80da49..0a702d2902 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -61,7 +61,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model: dict[str, Any] = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": 
""}, "top_k": 4, @@ -692,7 +692,7 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 2, diff --git a/api/core/rag/retrieval/retrieval_methods.py b/api/core/rag/retrieval/retrieval_methods.py index c7c6e60c8d..5f0f2a9d33 100644 --- a/api/core/rag/retrieval/retrieval_methods.py +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -9,8 +9,8 @@ class RetrievalMethod(Enum): @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} + return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH} @staticmethod def is_support_fulltext_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} + return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH} diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 8356861242..801d2a2a52 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from typing import Any from core.model_manager import ModelInstance @@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator - self._separators = separators or ["\n\n", "\n", " ", ""] + self._separators = separators or ["\n\n", "\n", "。", ". 
", " ", ""] def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" @@ -90,16 +91,19 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) # Now that we have the separator, split the text if separator: if separator == " ": - splits = text.split() + splits = re.split(r" +", text) else: splits = text.split(separator) splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] else: splits = list(text) - splits = [s for s in splits if (s not in {"", "\n"})] + if separator == "\n": + splits = [s for s in splits if s != ""] + else: + splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits - _separator = "" if self._keep_separator else separator + _separator = separator if self._keep_separator else "" s_lens = self._length_function(splits) if separator != "": for s, s_len in zip(splits, s_lens): diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 45fd16d684..2e94907f30 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -90,7 +90,7 @@ class BuiltinToolProviderController(ToolProviderController): tools.append( assistant_tool_class( provider=provider, - entity=ToolEntity(**tool), + entity=ToolEntity.model_validate(tool), runtime=ToolRuntime(tenant_id=""), ) ) @@ -111,7 +111,7 @@ class BuiltinToolProviderController(ToolProviderController): :return: the credentials schema """ - return self.get_credentials_schema_by_type(CredentialType.API_KEY.value) + return self.get_credentials_schema_by_type(CredentialType.API_KEY) def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: """ @@ -122,7 +122,7 @@ class BuiltinToolProviderController(ToolProviderController): """ if credential_type == CredentialType.OAUTH2.value: return self.entity.oauth_schema.credentials_schema.copy() if 
self.entity.oauth_schema else [] - if credential_type == CredentialType.API_KEY.value: + if credential_type == CredentialType.API_KEY: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] raise ValueError(f"Invalid credential type: {credential_type}") @@ -134,15 +134,15 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] - def get_supported_credential_types(self) -> list[str]: + def get_supported_credential_types(self) -> list[CredentialType]: """ returns the credential support type of the provider """ types = [] if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0: - types.append(CredentialType.API_KEY.value) + types.append(CredentialType.API_KEY) if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0: - types.append(CredentialType.OAUTH2.value) + types.append(CredentialType.OAUTH2) return types def get_tools(self) -> list[BuiltinTool]: diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 34d0f5c622..f18f638f2d 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -290,6 +290,7 @@ class ApiTool(Tool): method_lc ]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926 url, + max_retries=0, params=params, headers=headers, cookies=cookies, diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 00c4ab9dd7..de6bf01ae9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -61,7 +61,7 @@ class ToolProviderApiEntity(BaseModel): for tool in tools: if tool.get("parameters"): for parameter in tool.get("parameters"): - if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: + if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES: 
parameter["type"] = "files" if parameter.get("input_schema") is None: parameter.pop("input_schema", None) @@ -110,7 +110,9 @@ class ToolProviderCredentialApiEntity(BaseModel): class ToolProviderCredentialInfoApiEntity(BaseModel): - supported_credential_types: list[str] = Field(description="The supported credential types of the provider") + supported_credential_types: list[CredentialType] = Field( + description="The supported credential types of the provider" + ) is_oauth_custom_client_enabled: bool = Field( default=False, description="Whether the OAuth custom client is enabled for the provider" ) diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 2c6d9c1964..21d310bbb9 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class I18nObject(BaseModel): @@ -11,11 +11,12 @@ class I18nObject(BaseModel): pt_BR: str | None = Field(default=None) ja_JP: str | None = Field(default=None) - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _populate_missing_locales(self): self.zh_Hans = self.zh_Hans or self.en_US self.pt_BR = self.pt_BR or self.en_US self.ja_JP = self.ja_JP or self.en_US + return self def to_dict(self): return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index a59b54216f..62e3aa8b5d 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -113,7 +113,7 @@ class ApiProviderAuthType(StrEnum): # normalize & tiny alias for backward compatibility v = (value or "").strip().lower() if v == "api_key": - v = cls.API_KEY_HEADER.value + v = cls.API_KEY_HEADER for mode in cls: if mode.value == v: diff --git 
a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 5b04f0edbe..f269b8db9b 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -54,7 +54,7 @@ class MCPToolProviderController(ToolProviderController): """ tools = [] tools_data = json.loads(db_provider.tools) - remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data] + remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data] user = db_provider.load_user() tools = [ ToolEntity( diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 9e5f5a7c23..af68971ca7 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1008,7 +1008,7 @@ class ToolManager: config = tool_configurations.get(parameter.name, {}) if not (config and isinstance(config, dict) and config.get("value") is not None): continue - tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {})) + tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": variable = variable_pool.get(tool_input.value) if variable is None: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 75c0c6738e..b5bc4d3c00 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -18,7 +18,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: dict[str, Any] = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 2, @@ -126,7 +126,7 @@ class 
DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): data_source_type=document.data_source_type, segment_id=segment.id, retriever_from=self.retriever_from, - score=document_score_list.get(segment.index_node_id, None), + score=document_score_list.get(segment.index_node_id), doc_metadata=document.doc_metadata, ) diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 0e2237befd..1eae582f67 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -17,7 +17,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model: dict[str, Any] = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "reranking_mode": "reranking_model", diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index fcb1d325af..c7ac3387e5 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -4,8 +4,8 @@ from json import loads as json_loads from json.decoder import JSONDecodeError from typing import Any +import httpx from flask import request -from requests import get from yaml import YAMLError, safe_load from core.tools.entities.common_entities import I18nObject @@ -334,15 +334,20 @@ class ApiBasedToolSchemaParser: raise ToolNotSupportedError("Only openapi is supported now.") # get openapi yaml - response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) - - if response.status_code != 200: - raise ToolProviderNotFoundError("cannot get openapi yaml from url.") - - return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( - response.text, 
extra_info=extra_info, warning=warning + response = httpx.get( + api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5 ) + try: + if response.status_code != 200: + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") + + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + finally: + response.close() + @staticmethod def auto_parse_to_tool_bundle( content: str, extra_info: dict | None = None, warning: dict | None = None @@ -388,7 +393,7 @@ class ApiBasedToolSchemaParser: openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( loaded_content, extra_info=extra_info, warning=warning ) - schema_type = ApiProviderSchemaType.OPENAPI.value + schema_type = ApiProviderSchemaType.OPENAPI return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e @@ -398,7 +403,7 @@ class ApiBasedToolSchemaParser: converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning ) - schema_type = ApiProviderSchemaType.SWAGGER.value + schema_type = ApiProviderSchemaType.SWAGGER return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( converted_swagger, extra_info=extra_info, warning=warning ), schema_type @@ -410,7 +415,7 @@ class ApiBasedToolSchemaParser: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( json_dumps(loaded_content), extra_info=extra_info, warning=warning ) - return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value + return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 056e17bf5d..c841459170 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ 
b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -105,10 +105,10 @@ class RedisChannel: command_type = CommandType(command_type_value) if command_type == CommandType.ABORT: - return AbortCommand(**data) + return AbortCommand.model_validate(data) else: # For other command types, use base class - return GraphEngineCommand(**data) + return GraphEngineCommand.model_validate(data) except (ValueError, TypeError): return None diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index a01686a4b8..4a24b18465 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -252,7 +252,10 @@ class AgentNode(Node): if all(isinstance(v, dict) for _, v in parameters.items()): params = {} for key, param in parameters.items(): - if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value: + if param.get("auto", ParamsAutoGenerated.OPEN) in ( + ParamsAutoGenerated.CLOSE, + 0, + ): value_param = param.get("value", {}) params[key] = value_param.get("value", "") if value_param is not None else None else: @@ -266,7 +269,7 @@ class AgentNode(Node): value = cast(list[dict[str, Any]], value) tool_value = [] for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value)) + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) setting_params = tool.get("settings", {}) parameters = tool.get("parameters", {}) manual_input_params = [key for key, value in parameters.items() if value is not None] @@ -417,7 +420,7 @@ class AgentNode(Node): def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: # get conversation id conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID.value] + ["sys", SystemVariableKey.CONVERSATION_ID] ) if not isinstance(conversation_id_variable, StringSegment): return None @@ -476,7 
+479,7 @@ class AgentNode(Node): if meta_version and Version(meta_version) > Version("0.0.1"): return tools else: - return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value] + return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] def _transform_message( self, diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 937f4c944f..e392cb5f5c 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -75,11 +75,11 @@ class DatasourceNode(Node): node_data = self._node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value]) + datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) if not datasource_type_segement: raise DatasourceNodeError("Datasource type is not set") datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None - datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value]) + datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) if not datasource_info_segement: raise DatasourceNodeError("Datasource info is not set") datasource_info_value = datasource_info_segement.value @@ -267,7 +267,7 @@ class DatasourceNode(Node): return result def _fetch_files(self, variable_pool: VariablePool) -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + variable = variable_pool.get(["sys", SystemVariableKey.FILES]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 2bdfe4efce..7ec74084d0 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ 
b/api/core/workflow/nodes/end/end_node.py @@ -16,7 +16,7 @@ class EndNode(Node): _node_data: EndNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = EndNodeData(**data) + self._node_data = EndNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 20e1337ea7..55dec3fb08 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -234,7 +234,7 @@ class HttpRequestNode(Node): mapping = { "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE.value, + "transfer_method": FileTransferMethod.TOOL_FILE, } file = file_factory.build_from_mapping( mapping=mapping, diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index a05a6b1b96..c089a68bd4 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -95,7 +95,7 @@ class IterationNode(Node): "config": { "is_parallel": False, "parallel_nums": 10, - "error_handle_mode": ErrorHandleMode.TERMINATED.value, + "error_handle_mode": ErrorHandleMode.TERMINATED, }, } @@ -342,10 +342,13 @@ class IterationNode(Node): iterator_list_value: Sequence[object], iter_run_map: dict[str, float], ) -> Generator[NodeEventBase, None, None]: + # Flatten the list of lists if all outputs are lists + flattened_outputs = self._flatten_outputs_if_needed(outputs) + yield IterationSucceededEvent( start_at=started_at, inputs=inputs, - outputs={"output": outputs}, + outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, @@ -357,13 +360,39 @@ class IterationNode(Node): yield StreamCompletedEvent( node_run_result=NodeRunResult( 
status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, + outputs={"output": flattened_outputs}, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, }, ) ) + def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]: + """ + Flatten the outputs list if all elements are lists. + This maintains backward compatibility with version 1.8.1 behavior. + """ + if not outputs: + return outputs + + # Check if all non-None outputs are lists + non_none_outputs = [output for output in outputs if output is not None] + if not non_none_outputs: + return outputs + + if all(isinstance(output, list) for output in non_none_outputs): + # Flatten the list of lists + flattened: list[Any] = [] + for output in outputs: + if isinstance(output, list): + flattened.extend(output) + elif output is not None: + # This shouldn't happen based on our check, but handle it gracefully + flattened.append(output) + return flattened + + return outputs + def _handle_iteration_failure( self, started_at: datetime, @@ -373,10 +402,13 @@ class IterationNode(Node): iter_run_map: dict[str, float], error: IterationNodeError, ) -> Generator[NodeEventBase, None, None]: + # Flatten the list of lists if all outputs are lists (even in failure case) + flattened_outputs = self._flatten_outputs_if_needed(outputs) + yield IterationFailedEvent( start_at=started_at, inputs=inputs, - outputs={"output": outputs}, + outputs={"output": flattened_outputs}, steps=len(iterator_list_value), metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 80f39ccebc..90b7f4539b 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -18,7 +18,7 @@ class IterationStartNode(Node): _node_data: 
IterationStartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = IterationStartNodeData(**data) + self._node_data = IterationStartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 05e0c7707a..2751f24048 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -27,7 +27,7 @@ from .exc import ( logger = logging.getLogger(__name__) default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 2, @@ -77,7 +77,7 @@ class KnowledgeIndexNode(Node): raise KnowledgeIndexNodeError("Index chunk variable is required.") invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) if invoke_from: - is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value + is_preview = invoke_from.value == InvokeFrom.DEBUGGER else: is_preview = False chunks = variable.value diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index b6128d3eab..7091b62463 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -72,7 +72,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 4, diff --git 
a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 3243b22d44..180eb2ad90 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -41,7 +41,7 @@ class ListOperatorNode(Node): _node_data: ListOperatorNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = ListOperatorNodeData(**data) + self._node_data = ListOperatorNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index ad969cdad1..aff84433b2 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -92,7 +92,7 @@ def fetch_memory( return None # get conversation id - conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value]) + conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) if not isinstance(conversation_id_variable, StringSegment): return None conversation_id = conversation_id_variable.value @@ -143,7 +143,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs Provider.tenant_id == tenant_id, # TODO: Use provider name with prefix after the data migration. 
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, + Provider.provider_type == ProviderType.SYSTEM, Provider.quota_type == system_configuration.current_quota_type.value, Provider.quota_limit > Provider.quota_used, ) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 4742476352..13f6d904e6 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -945,7 +945,7 @@ class LLMNode(Node): variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector if typed_node_data.memory: - variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] + variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] if typed_node_data.prompt_config: enable_jinja = False diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 38aef06d24..e5bce1230c 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -18,7 +18,7 @@ class LoopEndNode(Node): _node_data: LoopEndNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopEndNodeData(**data) + self._node_data = LoopEndNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index e777a8cbe9..e065dc90a0 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -18,7 +18,7 @@ class LoopStartNode(Node): _node_data: LoopStartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopStartNodeData(**data) + self._node_data = LoopStartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy 
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 2f33c54128..3b134be1a1 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -16,7 +16,7 @@ class StartNode(Node): _node_data: StartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = StartNodeData(**data) + self._node_data = StartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index ce1a879ff1..cd0094f531 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -224,7 +224,7 @@ class ToolNode(Node): return result def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + variable = variable_pool.get(["sys", SystemVariableKey.FILES]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index be00d55937..0ac0d3d858 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -15,7 +15,7 @@ class VariableAggregatorNode(Node): _node_data: VariableAssignerNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = VariableAssignerNodeData(**data) + self._node_data = VariableAssignerNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 3801dfe15d..4cd885cfa5 100644 --- 
a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -227,7 +227,7 @@ class WorkflowEntry: "height": node_height, "type": "custom", "data": { - "type": NodeType.START.value, + "type": NodeType.START, "title": "Start", "desc": "Start", }, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 6c9fc0bf1d..1b44d8a1e2 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -12,9 +12,9 @@ def handle(sender, **kwargs): if synced_draft_workflow is None: return for node_data in synced_draft_workflow.graph_dict.get("nodes", []): - if node_data.get("data", {}).get("type") == NodeType.TOOL.value: + if node_data.get("data", {}).get("type") == NodeType.TOOL: try: - tool_entity = ToolEntity(**node_data["data"]) + tool_entity = ToolEntity.model_validate(node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( provider_type=tool_entity.provider_type, provider_id=tool_entity.provider_id, diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 898ec1f153..53e0065f6e 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -53,7 +53,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: # fetch all knowledge retrieval nodes knowledge_retrieval_nodes = [ - node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value + node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL ] if not knowledge_retrieval_nodes: @@ 
-61,7 +61,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: for node in knowledge_retrieval_nodes: try: - node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) + node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {})) dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids) except Exception: continue diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 27efa539dc..c0694d4efe 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -139,7 +139,7 @@ def handle(sender: Message, **kwargs): filters=_ProviderUpdateFilters( tenant_id=tenant_id, provider_name=ModelProviderID(model_config.provider).provider_name, - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, quota_type=provider_configuration.system_configuration.current_quota_type.value, ), values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 19c6e68c6b..cb6e4849a9 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -138,7 +138,6 @@ def init_app(app: DifyApp): from opentelemetry.instrumentation.flask import FlaskInstrumentor from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor from opentelemetry.instrumentation.redis import RedisInstrumentor - from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider from opentelemetry.propagate import set_global_textmap @@ -238,7 +237,6 @@ def init_app(app: DifyApp): instrument_exception_logging() init_sqlalchemy_instrumentor(app) 
RedisInstrumentor().instrument() - RequestsInstrumentor().instrument() HTTPXClientInstrumentor().instrument() atexit.register(shutdown_tracer) diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 6ab02ad8cc..dc5aa8e39c 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -264,7 +264,7 @@ class FileLifecycleManager: logger.warning("File %s not found in metadata", filename) return False - metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value + metadata_dict[filename]["status"] = FileStatus.ARCHIVED metadata_dict[filename]["modified_at"] = datetime.now().isoformat() self._save_metadata(metadata_dict) @@ -309,7 +309,7 @@ class FileLifecycleManager: # Update metadata metadata_dict = self._load_metadata() if filename in metadata_dict: - metadata_dict[filename]["status"] = FileStatus.DELETED.value + metadata_dict[filename]["status"] = FileStatus.DELETED metadata_dict[filename]["modified_at"] = datetime.now().isoformat() self._save_metadata(metadata_dict) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index d66c757249..69fd1a6da3 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -45,7 +45,7 @@ def build_from_message_file( } # Set the correct ID field based on transfer method - if message_file.transfer_method == FileTransferMethod.TOOL_FILE.value: + if message_file.transfer_method == FileTransferMethod.TOOL_FILE: mapping["tool_file_id"] = message_file.upload_file_id else: mapping["upload_file_id"] = message_file.upload_file_id @@ -368,9 +368,7 @@ def _build_from_datasource_file( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type - ) + file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type return File( id=mapping.get("datasource_file_id"), diff --git a/api/models/account.py b/api/models/account.py index 8c1f990aa2..86cd9e41b5 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,15 +1,16 @@ import enum import json +from dataclasses import field from datetime import datetime from typing import Any, Optional import sqlalchemy as sa from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor +from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated -from models.base import Base +from models.base import TypeBase from .engine import db from .types import StringUUID @@ -83,31 +84,37 @@ class AccountStatus(enum.StrEnum): CLOSED = "closed" -class Account(UserMixin, Base): +class Account(UserMixin, TypeBase): __tablename__ = "accounts" __table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email")) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) name: Mapped[str] = mapped_column(String(255)) email: Mapped[str] = mapped_column(String(255)) - password: Mapped[str | None] = mapped_column(String(255)) - password_salt: Mapped[str | None] = mapped_column(String(255)) - avatar: Mapped[str | None] = mapped_column(String(255), nullable=True) - interface_language: Mapped[str | None] = mapped_column(String(255)) - interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True) - timezone: Mapped[str | 
None] = mapped_column(String(255)) - last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True) - last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying")) - initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + password: Mapped[str | None] = mapped_column(String(255), default=None) + password_salt: Mapped[str | None] = mapped_column(String(255), default=None) + avatar: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + interface_language: Mapped[str | None] = mapped_column(String(255), default=None) + interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + timezone: Mapped[str | None] = mapped_column(String(255), default=None) + last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + last_active_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + status: Mapped[str] = mapped_column( + String(16), server_default=sa.text("'active'::character varying"), default="active" + ) + initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, 
server_default=func.current_timestamp(), nullable=False, init=False + ) - @reconstructor - def init_on_load(self): - self.role: TenantAccountRole | None = None - self._current_tenant: Tenant | None = None + role: TenantAccountRole | None = field(default=None, init=False) + _current_tenant: "Tenant | None" = field(default=None, init=False) @property def is_password_set(self): @@ -226,18 +233,24 @@ class TenantStatus(enum.StrEnum): ARCHIVE = "archive" -class Tenant(Base): +class Tenant(TypeBase): __tablename__ = "tenants" __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) name: Mapped[str] = mapped_column(String(255)) - encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text) - plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) - custom_config: Mapped[str | None] = mapped_column(sa.Text) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text, default=None) + plan: Mapped[str] = mapped_column( + String(255), server_default=sa.text("'basic'::character varying"), default="basic" + ) + status: Mapped[str] = mapped_column( + String(255), server_default=sa.text("'normal'::character varying"), default="normal" + ) + custom_config: Mapped[str | None] = mapped_column(sa.Text, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column(DateTime, 
server_default=func.current_timestamp(), init=False) def get_accounts(self) -> list[Account]: return list( @@ -257,7 +270,7 @@ class Tenant(Base): self.custom_config = json.dumps(value) -class TenantAccountJoin(Base): +class TenantAccountJoin(TypeBase): __tablename__ = "tenant_account_joins" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -266,17 +279,21 @@ class TenantAccountJoin(Base): sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) - current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) - role: Mapped[str] = mapped_column(String(16), server_default="normal") - invited_by: Mapped[str | None] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False) + role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal") + invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) -class AccountIntegrate(Base): +class AccountIntegrate(TypeBase): __tablename__ = "account_integrates" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -284,16 +301,20 @@ class 
AccountIntegrate(Base): sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) account_id: Mapped[str] = mapped_column(StringUUID) provider: Mapped[str] = mapped_column(String(16)) open_id: Mapped[str] = mapped_column(String(255)) encrypted_token: Mapped[str] = mapped_column(String(255)) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, server_default=func.current_timestamp(), nullable=False, init=False + ) -class InvitationCode(Base): +class InvitationCode(TypeBase): __tablename__ = "invitation_codes" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"), @@ -301,18 +322,22 @@ class InvitationCode(Base): sa.Index("invitation_codes_code_idx", "code", "status"), ) - id: Mapped[int] = mapped_column(sa.Integer) + id: Mapped[int] = mapped_column(sa.Integer, init=False) batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) - used_at: Mapped[datetime | None] = mapped_column(DateTime) - used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID) - used_by_account_id: Mapped[str | None] = mapped_column(StringUUID) - deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) + status: Mapped[str] = mapped_column( + String(16), 
server_default=sa.text("'unused'::character varying"), default="unused" + ) + used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None) + used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None) + used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None) + deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"), nullable=False, init=False + ) -class TenantPluginPermission(Base): +class TenantPluginPermission(TypeBase): class InstallPermission(enum.StrEnum): EVERYONE = "everyone" ADMINS = "admins" @@ -329,13 +354,17 @@ class TenantPluginPermission(Base): sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") - debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") + install_permission: Mapped[InstallPermission] = mapped_column( + String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE + ) + debug_permission: Mapped[DebugPermission] = mapped_column( + String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY + ) -class TenantPluginAutoUpgradeStrategy(Base): +class TenantPluginAutoUpgradeStrategy(TypeBase): class StrategySetting(enum.StrEnum): DISABLED = "disabled" FIX_ONLY = "fix_only" @@ -352,12 +381,20 @@ class TenantPluginAutoUpgradeStrategy(Base): sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), ) - id: Mapped[str] = 
mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") - upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day - upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") - exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) - include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + strategy_setting: Mapped[StrategySetting] = mapped_column( + String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY + ) + upgrade_mode: Mapped[UpgradeMode] = mapped_column( + String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE + ) + exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list) + include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False, default_factory=list) + upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py 
index 60167d9069..e86826fc3d 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -9,7 +9,7 @@ from .base import Base from .types import StringUUID -class APIBasedExtensionPoint(enum.Enum): +class APIBasedExtensionPoint(enum.StrEnum): APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query" PING = "ping" APP_MODERATION_INPUT = "app.moderation.input" diff --git a/api/models/dataset.py b/api/models/dataset.py index 25ebe14738..5653445f2b 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -61,18 +61,18 @@ class Dataset(Base): created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = mapped_column(db.String(255), nullable=True) embedding_model_provider = mapped_column(db.String(255), nullable=True) - keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10")) + keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10")) collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(JSONB, nullable=True) - built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - icon_info = db.Column(JSONB, nullable=True) - runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) - pipeline_id = db.Column(StringUUID, nullable=True) - chunk_structure = db.Column(db.String(255), nullable=True) - enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + 
icon_info = mapped_column(JSONB, nullable=True) + runtime_mode = mapped_column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) + pipeline_id = mapped_column(StringUUID, nullable=True) + chunk_structure = mapped_column(db.String(255), nullable=True) + enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true")) @property def total_documents(self): @@ -184,7 +184,7 @@ class Dataset(Base): @property def retrieval_model_dict(self): default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 2, @@ -754,7 +754,7 @@ class DocumentSegment(Base): if process_rule and process_rule.mode == "hierarchical": rules_dict = process_rule.rules_dict if rules_dict: - rules = Rule(**rules_dict) + rules = Rule.model_validate(rules_dict) if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: child_chunks = ( db.session.query(ChildChunk) @@ -772,7 +772,7 @@ class DocumentSegment(Base): if process_rule and process_rule.mode == "hierarchical": rules_dict = process_rule.rules_dict if rules_dict: - rules = Rule(**rules_dict) + rules = Rule.model_validate(rules_dict) if rules.parent_mode: child_chunks = ( db.session.query(ChildChunk) @@ -1226,21 +1226,21 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_built_in_templates" __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False) - chunk_structure = db.Column(db.String(255), nullable=False) - icon = db.Column(db.JSON, nullable=False) - yaml_content = db.Column(db.Text, nullable=False) - copyright = db.Column(db.String(255), nullable=False) 
- privacy_policy = db.Column(db.String(255), nullable=False) - position = db.Column(db.Integer, nullable=False) - install_count = db.Column(db.Integer, nullable=False, default=0) - language = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by = db.Column(StringUUID, nullable=False) - updated_by = db.Column(StringUUID, nullable=True) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(sa.Text, nullable=False) + chunk_structure = mapped_column(db.String(255), nullable=False) + icon = mapped_column(sa.JSON, nullable=False) + yaml_content = mapped_column(sa.Text, nullable=False) + copyright = mapped_column(db.String(255), nullable=False) + privacy_policy = mapped_column(db.String(255), nullable=False) + position = mapped_column(sa.Integer, nullable=False) + install_count = mapped_column(sa.Integer, nullable=False, default=0) + language = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by = mapped_column(StringUUID, nullable=False) + updated_by = mapped_column(StringUUID, nullable=True) @property def created_user_name(self): @@ -1257,20 +1257,20 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] db.Index("pipeline_customized_template_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - tenant_id = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False) - chunk_structure = db.Column(db.String(255), nullable=False) - icon = 
db.Column(db.JSON, nullable=False) - position = db.Column(db.Integer, nullable=False) - yaml_content = db.Column(db.Text, nullable=False) - install_count = db.Column(db.Integer, nullable=False, default=0) - language = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - updated_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id = mapped_column(StringUUID, nullable=False) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(sa.Text, nullable=False) + chunk_structure = mapped_column(db.String(255), nullable=False) + icon = mapped_column(sa.JSON, nullable=False) + position = mapped_column(sa.Integer, nullable=False) + yaml_content = mapped_column(sa.Text, nullable=False) + install_count = mapped_column(sa.Integer, nullable=False, default=0) + language = mapped_column(db.String(255), nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + updated_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def created_user_name(self): @@ -1284,17 +1284,17 @@ class Pipeline(Base): # type: ignore[name-defined] __tablename__ = "pipelines" __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) - workflow_id = db.Column(StringUUID, 
nullable=True) - is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(sa.Text, nullable=False, server_default=db.text("''::character varying")) + workflow_id = mapped_column(StringUUID, nullable=True) + is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) def retrieve_dataset(self, session: Session): return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() @@ -1307,25 +1307,25 @@ class DocumentPipelineExecutionLog(Base): db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - pipeline_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) - datasource_type = db.Column(db.String(255), nullable=False) - datasource_info = db.Column(db.Text, nullable=False) - datasource_node_id = db.Column(db.String(255), nullable=False) - input_data = db.Column(db.JSON, nullable=False) 
- created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + pipeline_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) + datasource_type = mapped_column(db.String(255), nullable=False) + datasource_info = mapped_column(sa.Text, nullable=False) + datasource_node_id = mapped_column(db.String(255), nullable=False) + input_data = mapped_column(sa.JSON, nullable=False) + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class PipelineRecommendedPlugin(Base): __tablename__ = "pipeline_recommended_plugins" __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - plugin_id = db.Column(db.Text, nullable=False) - provider_name = db.Column(db.Text, nullable=False) - position = db.Column(db.Integer, nullable=False, default=0) - active = db.Column(db.Boolean, nullable=False, default=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + plugin_id = mapped_column(sa.Text, nullable=False) + provider_name = mapped_column(sa.Text, nullable=False) + position = mapped_column(sa.Integer, nullable=False, default=0) + active = mapped_column(sa.Boolean, nullable=False, default=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index a8218c3a4e..18958c8253 100644 --- 
a/api/models/model.py +++ b/api/models/model.py @@ -186,13 +186,13 @@ class App(Base): if len(keys) >= 4: provider_type = tool.get("provider_type", "") provider_id = tool.get("provider_id", "") - if provider_type == ToolProviderType.API.value: + if provider_type == ToolProviderType.API: try: uuid.UUID(provider_id) except Exception: continue api_provider_ids.append(provider_id) - if provider_type == ToolProviderType.BUILT_IN.value: + if provider_type == ToolProviderType.BUILT_IN: try: # check if it's hardcoded try: @@ -251,23 +251,23 @@ class App(Base): provider_type = tool.get("provider_type", "") provider_id = tool.get("provider_id", "") - if provider_type == ToolProviderType.API.value: + if provider_type == ToolProviderType.API: if uuid.UUID(provider_id) not in existing_api_providers: deleted_tools.append( { - "type": ToolProviderType.API.value, + "type": ToolProviderType.API, "tool_name": tool["tool_name"], "provider_id": provider_id, } ) - if provider_type == ToolProviderType.BUILT_IN.value: + if provider_type == ToolProviderType.BUILT_IN: generic_provider_id = GenericProviderID(provider_id) if not existing_builtin_providers[generic_provider_id.provider_name]: deleted_tools.append( { - "type": ToolProviderType.BUILT_IN.value, + "type": ToolProviderType.BUILT_IN, "tool_name": tool["tool_name"], "provider_id": provider_id, # use the original one } @@ -1154,7 +1154,7 @@ class Message(Base): files: list[File] = [] for message_file in message_files: - if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: + if message_file.transfer_method == FileTransferMethod.LOCAL_FILE: if message_file.upload_file_id is None: raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") file = file_factory.build_from_mapping( @@ -1166,7 +1166,7 @@ class Message(Base): }, tenant_id=current_app.tenant_id, ) - elif message_file.transfer_method == FileTransferMethod.REMOTE_URL.value: + elif message_file.transfer_method == 
FileTransferMethod.REMOTE_URL: if message_file.url is None: raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") file = file_factory.build_from_mapping( @@ -1179,7 +1179,7 @@ class Message(Base): }, tenant_id=current_app.tenant_id, ) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE.value: + elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: if message_file.upload_file_id is None: assert message_file.url is not None message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] diff --git a/api/models/oauth.py b/api/models/oauth.py index 1d5d37e3e1..ef23780dc8 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,7 +1,8 @@ from datetime import datetime +import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Mapped +from sqlalchemy.orm import Mapped, mapped_column from .base import Base from .engine import db @@ -15,10 +16,10 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) - provider: Mapped[str] = db.Column(db.String(255), nullable=False) - system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + system_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False) class DatasourceProvider(Base): @@ -28,19 +29,19 @@ class DatasourceProvider(Base): db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"), db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"), ) - 
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - tenant_id = db.Column(StringUUID, nullable=False) - name: Mapped[str] = db.Column(db.String(255), nullable=False) - provider: Mapped[str] = db.Column(db.String(255), nullable=False) - plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) - auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) - encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) - avatar_url: Mapped[str] = db.Column(db.Text, nullable=True, default="default") - is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1") + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id = mapped_column(StringUUID, nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False) + auth_type: Mapped[str] = mapped_column(db.String(255), nullable=False) + encrypted_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False) + avatar_url: Mapped[str] = mapped_column(sa.Text, nullable=True, default="default") + is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1") - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) - updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now) class DatasourceOauthTenantParamConfig(Base): @@ -50,12 +51,12 @@ class DatasourceOauthTenantParamConfig(Base): 
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"), ) - id = db.Column(StringUUID, server_default=db.text("uuidv7()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider: Mapped[str] = db.Column(db.String(255), nullable=False) - plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) - client_params: Mapped[dict] = db.Column(JSONB, nullable=False, default={}) - enabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, default=False) + id = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id = mapped_column(StringUUID, nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False) + client_params: Mapped[dict] = mapped_column(JSONB, nullable=False, default={}) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) - updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now) diff --git a/api/models/provider.py b/api/models/provider.py index aacc6e505a..f6852d49f4 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -107,7 +107,7 @@ class Provider(Base): """ Returns True if the provider is enabled. 
""" - if self.provider_type == ProviderType.SYSTEM.value: + if self.provider_type == ProviderType.SYSTEM: return self.is_valid else: return self.is_valid and self.token_is_set diff --git a/api/models/task.py b/api/models/task.py index 3da1674536..4e49254dbd 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -8,8 +8,6 @@ from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now from models.base import Base -from .engine import db - class CeleryTask(Base): """Task result/status.""" @@ -19,7 +17,7 @@ class CeleryTask(Base): id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) task_id = mapped_column(String(155), unique=True) status = mapped_column(String(50), default=states.PENDING) - result = mapped_column(db.PickleType, nullable=True) + result = mapped_column(sa.PickleType, nullable=True) date_done = mapped_column( DateTime, default=lambda: naive_utc_now(), @@ -44,5 +42,5 @@ class CeleryTaskSet(Base): sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True ) taskset_id = mapped_column(String(155), unique=True) - result = mapped_column(db.PickleType, nullable=True) + result = mapped_column(sa.PickleType, nullable=True) date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 7211d7aa3a..d581d588a4 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -152,7 +152,7 @@ class ApiToolProvider(Base): def tools(self) -> list["ApiToolBundle"]: from core.tools.entities.tool_bundle import ApiToolBundle - return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] + return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)] @property def credentials(self) -> dict[str, Any]: @@ -242,7 +242,10 @@ class WorkflowToolProvider(Base): def parameter_configurations(self) -> 
list["WorkflowToolParameterConfiguration"]: from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration - return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] + return [ + WorkflowToolParameterConfiguration.model_validate(config) + for config in json.loads(self.parameter_configuration) + ] @property def app(self) -> App | None: @@ -312,7 +315,7 @@ class MCPToolProvider(Base): def mcp_tools(self) -> list["MCPTool"]: from core.mcp.types import Tool as MCPTool - return [MCPTool(**tool) for tool in json.loads(self.tools)] + return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)] @property def provider_icon(self) -> Mapping[str, str] | str: @@ -552,4 +555,4 @@ class DeprecatedPublishedAppTool(Base): def description_i18n(self) -> "I18nObject": from core.tools.entities.common_entities import I18nObject - return I18nObject(**json.loads(self.description)) + return I18nObject.model_validate(json.loads(self.description)) diff --git a/api/models/workflow.py b/api/models/workflow.py index e61005953e..b898f02612 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -360,7 +360,9 @@ class Workflow(Base): @property def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: - # _environment_variables is guaranteed to be non-None due to server_default="{}" + # TODO: find some way to init `self._environment_variables` when instance created. + if self._environment_variables is None: + self._environment_variables = "{}" # Use workflow.tenant_id to avoid relying on request user in background threads tenant_id = self.tenant_id @@ -444,7 +446,9 @@ class Workflow(Base): @property def conversation_variables(self) -> Sequence[Variable]: - # _conversation_variables is guaranteed to be non-None due to server_default="{}" + # TODO: find some way to init `self._conversation_variables` when instance created. 
+ if self._conversation_variables is None: + self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] @@ -825,14 +829,14 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo if self.execution_metadata_dict: from core.workflow.nodes import NodeType - if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: + if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict: tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - elif self.node_type == NodeType.DATASOURCE.value and "datasource_info" in self.execution_metadata_dict: + elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict: datasource_info = self.execution_metadata_dict["datasource_info"] extras["icon"] = datasource_info.get("icon") return extras diff --git a/api/pyproject.toml b/api/pyproject.toml index 1f51d60098..22eedf7b8b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -48,7 +48,7 @@ dependencies = [ "opentelemetry-instrumentation-flask==0.48b0", "opentelemetry-instrumentation-httpx==0.48b0", "opentelemetry-instrumentation-redis==0.48b0", - "opentelemetry-instrumentation-requests==0.48b0", + "opentelemetry-instrumentation-httpx==0.48b0", "opentelemetry-instrumentation-sqlalchemy==0.48b0", "opentelemetry-propagator-b3==1.27.0", # opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0), @@ -110,7 +110,7 @@ dev = [ "lxml-stubs~=0.5.1", "ty~=0.0.1a19", "basedpyright~=1.31.0", - "ruff~=0.12.3", + "ruff~=0.14.0", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", "pytest-cov~=4.1.0", @@ -145,8 +145,6 @@ dev = [ "types-pywin32~=310.0.0", 
"types-pyyaml~=6.0.12", "types-regex~=2024.11.6", - "types-requests~=2.32.0", - "types-requests-oauthlib~=2.0.0", "types-shapely~=2.0.0", "types-simplejson>=3.20.0", "types-six>=1.17.0", diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 67571316a9..bf4ec2314e 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -15,7 +15,8 @@ "opentelemetry.instrumentation.httpx", "opentelemetry.instrumentation.requests", "opentelemetry.instrumentation.sqlalchemy", - "opentelemetry.instrumentation.redis" + "opentelemetry.instrumentation.redis", + "opentelemetry.instrumentation.httpx" ], "reportUnknownMemberType": "hint", "reportUnknownParameterType": "hint", diff --git a/api/services/account_service.py b/api/services/account_service.py index 0e699d16da..106bc0e77e 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -127,7 +127,7 @@ class AccountService: if not account: return None - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() @@ -178,7 +178,7 @@ class AccountService: if not account: raise AccountPasswordError("Invalid email or password.") - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise AccountLoginError("Account is banned.") if password and invite_token and account.password is None: @@ -193,8 +193,8 @@ class AccountService: if account.password is None or not compare_password(password, account.password, account.password_salt): raise AccountPasswordError("Invalid email or password.") - if account.status == AccountStatus.PENDING.value: - account.status = AccountStatus.ACTIVE.value + if account.status == AccountStatus.PENDING: + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() @@ -246,10 +246,8 @@ class AccountService: ) ) 
- account = Account() - account.email = email - account.name = name - + password_to_set = None + salt_to_set = None if password: valid_password(password) @@ -261,14 +259,18 @@ class AccountService: password_hashed = hash_password(password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt + password_to_set = base64_password_hashed + salt_to_set = base64_salt - account.interface_language = interface_language - account.interface_theme = interface_theme - - # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, "UTC") + account = Account( + name=name, + email=email, + password=password_to_set, + password_salt=salt_to_set, + interface_language=interface_language, + interface_theme=interface_theme, + timezone=language_timezone_mapping.get(interface_language, "UTC"), + ) db.session.add(account) db.session.commit() @@ -355,7 +357,7 @@ class AccountService: @staticmethod def close_account(account: Account): """Close account""" - account.status = AccountStatus.CLOSED.value + account.status = AccountStatus.CLOSED db.session.commit() @staticmethod @@ -395,8 +397,8 @@ class AccountService: if ip_address: AccountService.update_login_info(account=account, ip_address=ip_address) - if account.status == AccountStatus.PENDING.value: - account.status = AccountStatus.ACTIVE.value + if account.status == AccountStatus.PENDING: + account.status = AccountStatus.ACTIVE db.session.commit() access_token = AccountService.get_account_jwt_token(account=account) @@ -764,7 +766,7 @@ class AccountService: if not account: return None - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") return account @@ -1028,7 +1030,7 @@ class TenantService: @staticmethod def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: 
"""Create tenant member""" - if role == TenantAccountRole.OWNER.value: + if role == TenantAccountRole.OWNER: if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]): logger.error("Tenant %s has already an owner.", tenant.id) raise Exception("Tenant already has an owner.") @@ -1313,7 +1315,7 @@ class RegisterService: password=password, is_setup=is_setup, ) - account.status = AccountStatus.ACTIVE.value if not status else status.value + account.status = status or AccountStatus.ACTIVE account.initialized_at = naive_utc_now() if open_id is not None and provider is not None: @@ -1374,7 +1376,7 @@ class RegisterService: TenantService.create_tenant_member(tenant, account, role) # Support resend invitation email when the account is pending status - if account.status != AccountStatus.PENDING.value: + if account.status != AccountStatus.PENDING: raise AccountAlreadyInTenantError("Account already in tenant.") token = cls.generate_invite_token(tenant, account) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 8701fe4f4e..311f80bef6 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -494,7 +494,7 @@ class AppDslService: unique_hash = None graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ decrypted_id @@ -584,17 +584,17 @@ class AppDslService: if not node_data: continue data_type = node_data.get("type", "") - if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value: + if data_type == NodeType.KNOWLEDGE_RETRIEVAL: dataset_ids = node_data.get("dataset_ids", []) node_data["dataset_ids"] = [ cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] # filter credential id from tool node - if not 
include_secret and data_type == NodeType.TOOL.value: + if not include_secret and data_type == NodeType.TOOL: node_data.pop("credential_id", None) # filter credential id from agent node - if not include_secret and data_type == NodeType.AGENT.value: + if not include_secret and data_type == NodeType.AGENT: for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): tool.pop("credential_id", None) @@ -658,32 +658,32 @@ class AppDslService: try: typ = node.get("data", {}).get("type") match typ: - case NodeType.TOOL.value: - tool_entity = ToolNodeData(**node["data"]) + case NodeType.TOOL: + tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) - case NodeType.LLM.value: - llm_entity = LLMNodeData(**node["data"]) + case NodeType.LLM: + llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) - case NodeType.QUESTION_CLASSIFIER.value: - question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + case NodeType.QUESTION_CLASSIFIER: + question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) - case NodeType.PARAMETER_EXTRACTOR.value: - parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + case NodeType.PARAMETER_EXTRACTOR: + parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) - case NodeType.KNOWLEDGE_RETRIEVAL.value: - knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + case NodeType.KNOWLEDGE_RETRIEVAL: + knowledge_retrieval_entity = 
KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: if ( @@ -773,7 +773,7 @@ class AppDslService: """ Returns the leaked dependencies in current workspace """ - dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies] if not dependencies: return [] diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 89a5d89f61..36b7084973 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -646,7 +646,7 @@ class DatasourceProviderService: name=db_provider_name, provider=provider_name, plugin_id=plugin_id, - auth_type=CredentialType.API_KEY.value, + auth_type=CredentialType.API_KEY, encrypted_credentials=credentials, ) session.add(datasource_provider) @@ -674,7 +674,7 @@ class DatasourceProviderService: secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: - if credential_form_schema.type.value == FormType.SECRET_INPUT.value: + if credential_form_schema.type.value == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.name) return secret_input_form_variables diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index edb76408e8..bdc960aa2d 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -1,10 +1,12 @@ import os +from collections.abc import Mapping +from typing import Any -import requests +import httpx class BaseRequest: - proxies = { + proxies: Mapping[str, str] | None = { "http": "", "https": "", } @@ -13,10 +15,31 @@ class BaseRequest: secret_key_header = "" @classmethod - def send_request(cls, method, endpoint, json=None, params=None): + def _build_mounts(cls) -> dict[str, httpx.BaseTransport] | None: + if not cls.proxies: + 
return None + + mounts: dict[str, httpx.BaseTransport] = {} + for scheme, value in cls.proxies.items(): + if not value: + continue + key = f"{scheme}://" if not scheme.endswith("://") else scheme + mounts[key] = httpx.HTTPTransport(proxy=value) + return mounts or None + + @classmethod + def send_request( + cls, + method: str, + endpoint: str, + json: Any | None = None, + params: Mapping[str, Any] | None = None, + ) -> Any: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) + mounts = cls._build_mounts() + with httpx.Client(mounts=mounts) as client: + response = client.request(method, url, json=json, params=params, headers=headers) return response.json() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index f8612456d6..4fbf33fd6f 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -70,7 +70,7 @@ class EnterpriseService: data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) if not data: raise ValueError("No data found.") - return WebAppSettings(**data) + return WebAppSettings.model_validate(data) @classmethod def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]: @@ -100,7 +100,7 @@ class EnterpriseService: data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) if not data: raise ValueError("No data found.") - return WebAppSettings(**data) + return WebAppSettings.model_validate(data) @classmethod def update_app_access_mode(cls, app_id: str, access_mode: str): diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 49d48f044c..0f5151919f 100644 --- a/api/services/entities/model_provider_entities.py +++ 
b/api/services/entities/model_provider_entities.py @@ -1,6 +1,7 @@ +from collections.abc import Sequence from enum import Enum -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config from core.entities.model_entities import ( @@ -71,7 +72,7 @@ class ProviderResponse(BaseModel): icon_large: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None - supported_model_types: list[ModelType] + supported_model_types: Sequence[ModelType] configurate_methods: list[ConfigurateMethod] provider_credential_schema: ProviderCredentialSchema | None = None model_credential_schema: ModelCredentialSchema | None = None @@ -82,9 +83,8 @@ class ProviderResponse(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -97,6 +97,7 @@ class ProviderResponse(BaseModel): self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) + return self class ProviderWithModelsResponse(BaseModel): @@ -112,9 +113,8 @@ class ProviderWithModelsResponse(BaseModel): status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -127,6 +127,7 @@ class ProviderWithModelsResponse(BaseModel): self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) + return self class SimpleProviderEntityResponse(SimpleProviderEntity): @@ -136,9 +137,8 @@ class 
SimpleProviderEntityResponse(SimpleProviderEntity): tenant_id: str - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): url_prefix = ( dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" ) @@ -151,6 +151,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) + return self class DefaultModelResponse(BaseModel): diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 00ec3babf3..aa29354a6e 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -15,7 +15,7 @@ from models.dataset import Dataset, DatasetQuery logger = logging.getLogger(__name__) default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "top_k": 4, @@ -46,7 +46,7 @@ class HitTestingService: from core.app.app_config.entities import MetadataFilteringCondition - metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions) + metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions) metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( dataset_ids=[dataset.id], diff --git a/api/services/ops_service.py b/api/services/ops_service.py index c214640653..b4b23b8360 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -123,7 +123,7 @@ class OpsService: config_class: type[BaseTracingConfig] = provider_config["config_class"] other_keys: list[str] = provider_config["other_keys"] - default_config_instance: BaseTracingConfig = config_class(**tracing_config) + default_config_instance = 
config_class.model_validate(tracing_config) for key in other_keys: if key in tracing_config and tracing_config[key] == "": tracing_config[key] = getattr(default_config_instance, key, None) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 99946d8fa9..dec92a6faa 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -242,7 +242,7 @@ class PluginMigration: if data.get("type") == "tool": provider_name = data.get("provider_name") provider_type = data.get("provider_type") - if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value: + if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN: result.append(ToolProviderID(provider_name).plugin_id) return result @@ -269,9 +269,9 @@ class PluginMigration: for tool in agent_config["tools"]: if isinstance(tool, dict): try: - tool_entity = AgentToolEntity(**tool) + tool_entity = AgentToolEntity.model_validate(tool) if ( - tool_entity.provider_type == ToolProviderType.BUILT_IN.value + tool_entity.provider_type == ToolProviderType.BUILT_IN and tool_entity.provider_id not in excluded_providers ): result.append(ToolProviderID(tool_entity.provider_id).plugin_id) diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 8f96842337..571ca6c7a6 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -1,6 +1,6 @@ import logging -import requests +import httpx from configs import dify_config from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval @@ -43,7 +43,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN url 
= f"{domain}/pipeline-templates/{template_id}" - response = requests.get(url, timeout=(3, 10)) + response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: return None data: dict = response.json() @@ -58,7 +58,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/pipeline-templates?language={language}" - response = requests.get(url, timeout=(3, 10)) + response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}") diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index fdaaa73bcc..13c0ca7392 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -358,7 +358,7 @@ class RagPipelineService: for node in nodes: if node.get("data", {}).get("type") == "knowledge-index": knowledge_configuration = node.get("data", {}) - knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) + knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration) # update dataset dataset = pipeline.retrieve_dataset(session=session) @@ -873,7 +873,7 @@ class RagPipelineService: variable_pool = node_instance.graph_runtime_state.variable_pool invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) if invoke_from: - if invoke_from.value == InvokeFrom.PUBLISHED.value: + if invoke_from.value == InvokeFrom.PUBLISHED: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index f74de1bcab..c02fad4dc6 100644 --- 
a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -288,7 +288,7 @@ class RagPipelineDslService: dataset_id = None for node in nodes: if node.get("data", {}).get("type") == "knowledge-index": - knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) + knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) if ( dataset and pipeline.is_published @@ -426,7 +426,7 @@ class RagPipelineDslService: dataset_id = None for node in nodes: if node.get("data", {}).get("type") == "knowledge-index": - knowledge_configuration = KnowledgeConfiguration(**node.get("data", {})) + knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) if not dataset: dataset = Dataset( tenant_id=account.current_tenant_id, @@ -556,7 +556,7 @@ class RagPipelineDslService: graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ decrypted_id @@ -613,7 +613,7 @@ class RagPipelineDslService: tenant_id=pipeline.tenant_id, app_id=pipeline.id, features="{}", - type=WorkflowType.RAG_PIPELINE.value, + type=WorkflowType.RAG_PIPELINE, version="draft", graph=json.dumps(graph), created_by=account.id, @@ -689,17 +689,17 @@ class RagPipelineDslService: if not node_data: continue data_type = node_data.get("type", "") - if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value: + if data_type == NodeType.KNOWLEDGE_RETRIEVAL: dataset_ids = node_data.get("dataset_ids", []) node["data"]["dataset_ids"] = [ self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) for dataset_id in dataset_ids ] # filter credential id from tool node - if not include_secret and data_type == NodeType.TOOL.value: + if not include_secret 
and data_type == NodeType.TOOL: node_data.pop("credential_id", None) # filter credential id from agent node - if not include_secret and data_type == NodeType.AGENT.value: + if not include_secret and data_type == NodeType.AGENT: for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): tool.pop("credential_id", None) @@ -733,36 +733,36 @@ class RagPipelineDslService: try: typ = node.get("data", {}).get("type") match typ: - case NodeType.TOOL.value: - tool_entity = ToolNodeData(**node["data"]) + case NodeType.TOOL: + tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) - case NodeType.DATASOURCE.value: - datasource_entity = DatasourceNodeData(**node["data"]) + case NodeType.DATASOURCE: + datasource_entity = DatasourceNodeData.model_validate(node["data"]) if datasource_entity.provider_type != "local_file": dependencies.append(datasource_entity.plugin_id) - case NodeType.LLM.value: - llm_entity = LLMNodeData(**node["data"]) + case NodeType.LLM: + llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) - case NodeType.QUESTION_CLASSIFIER.value: - question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + case NodeType.QUESTION_CLASSIFIER: + question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) - case NodeType.PARAMETER_EXTRACTOR.value: - parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + case NodeType.PARAMETER_EXTRACTOR: + parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) 
- case NodeType.KNOWLEDGE_INDEX.value: - knowledge_index_entity = KnowledgeConfiguration(**node["data"]) + case NodeType.KNOWLEDGE_INDEX: + knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) if knowledge_index_entity.indexing_technique == "high_quality": if knowledge_index_entity.embedding_model_provider: dependencies.append( @@ -782,8 +782,8 @@ class RagPipelineDslService: knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name ), ) - case NodeType.KNOWLEDGE_RETRIEVAL.value: - knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + case NodeType.KNOWLEDGE_RETRIEVAL: + knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: if ( @@ -873,7 +873,7 @@ class RagPipelineDslService: """ Returns the leaked dependencies in current workspace """ - dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies] if not dependencies: return [] @@ -927,7 +927,7 @@ class RagPipelineDslService: account = cast(Account, current_user) rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline( account=account, - import_mode=ImportMode.YAML_CONTENT.value, + import_mode=ImportMode.YAML_CONTENT, yaml_content=rag_pipeline_dataset_create_entity.yaml_content, dataset=None, dataset_name=rag_pipeline_dataset_create_entity.name, diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 3d5a85b57f..39f426a2b0 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -156,13 +156,13 @@ class RagPipelineTransformService: self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: 
dict ): knowledge_configuration_dict = node.get("data", {}) - knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict) + knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict) if indexing_technique == "high_quality": knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - retrieval_setting = RetrievalSetting(**retrieval_model) + retrieval_setting = RetrievalSetting.model_validate(retrieval_model) if indexing_technique == "economy": retrieval_setting.search_method = "keyword_search" knowledge_configuration.retrieval_model = retrieval_setting @@ -214,7 +214,7 @@ class RagPipelineTransformService: tenant_id=pipeline.tenant_id, app_id=pipeline.id, features="{}", - type=WorkflowType.RAG_PIPELINE.value, + type=WorkflowType.RAG_PIPELINE, version="draft", graph=json.dumps(graph), created_by=current_user.id, @@ -226,7 +226,7 @@ class RagPipelineTransformService: tenant_id=pipeline.tenant_id, app_id=pipeline.id, features="{}", - type=WorkflowType.RAG_PIPELINE.value, + type=WorkflowType.RAG_PIPELINE, version=str(datetime.now(UTC).replace(tzinfo=None)), graph=json.dumps(graph), created_by=current_user.id, diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 2d57769f63..b217c9026a 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -1,6 +1,6 @@ import logging -import requests +import httpx from configs import dify_config from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval @@ -43,7 +43,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/apps/{app_id}" - response = requests.get(url, timeout=(3, 10)) + response = 
httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: return None data: dict = response.json() @@ -58,7 +58,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/apps?language={language}" - response = requests.get(url, timeout=(3, 10)) + response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index f86d7e51bf..2c0c63f634 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -277,7 +277,7 @@ class ApiToolManageService: provider.icon = json.dumps(icon) provider.schema = schema provider.description = extra_info.get("description", "") - provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value + provider.schema_type_str = ApiProviderSchemaType.OPENAPI provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer @@ -393,7 +393,7 @@ class ApiToolManageService: icon="", schema=schema, description="", - schema_type_str=ApiProviderSchemaType.OPENAPI.value, + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), ) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 6b36ed0eb7..81b4d6993a 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -50,16 +50,16 @@ class ToolTransformService: URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider" ) - if provider_type == ToolProviderType.BUILT_IN.value: + if 
provider_type == ToolProviderType.BUILT_IN: return str(url_prefix / "builtin" / provider_name / "icon") - elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: + elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}: try: if isinstance(icon, str): return json.loads(icon) return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} - elif provider_type == ToolProviderType.MCP.value: + elif provider_type == ToolProviderType.MCP: return icon return "" @@ -242,7 +242,7 @@ class ToolTransformService: is_team_authorization=db_provider.authed, server_url=db_provider.masked_server_url, tools=ToolTransformService.mcp_tool_to_user_tool( - db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] + db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)] ), updated_at=int(db_provider.updated_at.timestamp()), label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), @@ -387,6 +387,7 @@ class ToolTransformService: labels=labels or [], ) else: + assert tool.operation_id return ToolApiEntity( author=tool.author, name=tool.operation_id or "", diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 1c559f2c2b..abc92a0181 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -134,7 +134,7 @@ class VectorService: ) # use full doc mode to generate segment's child chunk processing_rule_dict = processing_rule.to_dict() - processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC documents = index_processor.transform( [document], embedding_model_instance=embedding_model_instance, diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 066dc9d741..d30e14f7a1 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -36,7 +36,7 @@ class 
WebAppAuthService: if not account: raise AccountNotFoundError() - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise AccountLoginError("Account is banned.") if account.password is None or not compare_password(password, account.password, account.password_salt): @@ -56,7 +56,7 @@ class WebAppAuthService: if not account: return None - if account.status == AccountStatus.BANNED.value: + if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") return account diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index ce7d16b3bd..9c09f54bf5 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -228,7 +228,7 @@ class WorkflowConverter: "position": None, "data": { "title": "START", - "type": NodeType.START.value, + "type": NodeType.START, "variables": [jsonable_encoder(v) for v in variables], }, } @@ -273,7 +273,7 @@ class WorkflowConverter: inputs[v.variable] = "{{#start." 
+ v.variable + "#}}" request_body = { - "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, "params": { "app_id": app_model.id, "tool_variable": tool_variable, @@ -290,7 +290,7 @@ class WorkflowConverter: "position": None, "data": { "title": f"HTTP REQUEST {api_based_extension.name}", - "type": NodeType.HTTP_REQUEST.value, + "type": NodeType.HTTP_REQUEST, "method": "post", "url": api_based_extension.api_endpoint, "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, @@ -308,7 +308,7 @@ class WorkflowConverter: "position": None, "data": { "title": f"Parse {api_based_extension.name} Response", - "type": NodeType.CODE.value, + "type": NodeType.CODE, "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], "code_language": "python3", "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" @@ -348,7 +348,7 @@ class WorkflowConverter: "position": None, "data": { "title": "KNOWLEDGE RETRIEVAL", - "type": NodeType.KNOWLEDGE_RETRIEVAL.value, + "type": NodeType.KNOWLEDGE_RETRIEVAL, "query_variable_selector": query_variable_selector, "dataset_ids": dataset_config.dataset_ids, "retrieval_mode": retrieve_config.retrieve_strategy.value, @@ -396,16 +396,16 @@ class WorkflowConverter: :param external_data_variable_node_mapping: external data variable node mapping """ # fetch start and knowledge retrieval node - start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"])) + start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START, graph["nodes"])) knowledge_retrieval_node = next( - filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None + filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL, graph["nodes"]), None ) role_prefix = None prompts: Any | None = None # Chat Model - if 
model_config.mode == LLMMode.CHAT.value: + if model_config.mode == LLMMode.CHAT: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: if not prompt_template.simple_prompt_template: raise ValueError("Simple prompt template is required") @@ -517,7 +517,7 @@ class WorkflowConverter: "position": None, "data": { "title": "LLM", - "type": NodeType.LLM.value, + "type": NodeType.LLM, "model": { "provider": model_config.provider, "name": model_config.model, @@ -572,7 +572,7 @@ class WorkflowConverter: "position": None, "data": { "title": "END", - "type": NodeType.END.value, + "type": NodeType.END, "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], }, } @@ -586,7 +586,7 @@ class WorkflowConverter: return { "id": "answer", "position": None, - "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"}, + "data": {"title": "ANSWER", "type": NodeType.ANSWER, "answer": "{{#llm.text#}}"}, } def _create_edge(self, source: str, target: str): diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 1378c20128..344b7486ee 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -569,7 +569,7 @@ class WorkflowDraftVariableService: system_instruction="", system_instruction_tokens=0, status="normal", - invoke_from=InvokeFrom.DEBUGGER.value, + invoke_from=InvokeFrom.DEBUGGER, from_source="console", from_end_user_id=None, from_account_id=account_id, diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 79d91cab4c..6a2edd912a 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -74,7 +74,7 @@ class WorkflowRunService: return self._workflow_run_repo.get_paginated_workflow_runs( tenant_id=app_model.tenant_id, app_id=app_model.id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + 
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, limit=limit, last_id=last_id, ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 359fdb85fd..dea6a657a4 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1006,7 +1006,7 @@ def _setup_variable_pool( ) # Only add chatflow-specific variables for non-workflow types - if workflow.type != WorkflowType.WORKFLOW.value: + if workflow.type != WorkflowType.WORKFLOW: system_variable.query = query system_variable.conversation_id = conversation_id system_variable.dialogue_count = 1 diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 7b254ac3b5..72e3b42ca7 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -36,7 +36,7 @@ def process_trace_tasks(file_info): if trace_info.get("workflow_data"): trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"]) if trace_info.get("documents"): - trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]] + trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: if trace_instance: diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index a2c99554f1..4171656131 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -79,7 +79,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], # Create Flask application context for this thread with flask_app.app_context(): try: - rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity) user_id = rag_pipeline_invoke_entity_model.user_id tenant_id = rag_pipeline_invoke_entity_model.tenant_id pipeline_id = 
rag_pipeline_invoke_entity_model.pipeline_id @@ -112,7 +112,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_execution_id = str(uuid.uuid4()) # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) + entity = RagPipelineGenerateEntity.model_validate(application_generate_entity) # Create workflow repositories session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 4e00f072bf..90ebe80daf 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -100,7 +100,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], # Create Flask application context for this thread with flask_app.app_context(): try: - rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity) + rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity) user_id = rag_pipeline_invoke_entity_model.user_id tenant_id = rag_pipeline_invoke_entity_model.tenant_id pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id @@ -133,7 +133,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_execution_id = str(uuid.uuid4()) # Create application generate entity from dict - entity = RagPipelineGenerateEntity(**application_generate_entity) + entity = RagPipelineGenerateEntity.model_validate(application_generate_entity) # Create workflow repositories session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index da1524ff2e..498ac56d5d 100644 --- 
a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -25,7 +25,7 @@ class TestChatMessageApiPermissions: """Create a mock App model for testing.""" app = App() app.id = str(uuid.uuid4()) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT app.tenant_id = str(uuid.uuid4()) app.status = "normal" return app @@ -33,17 +33,19 @@ class TestChatMessageApiPermissions: @pytest.fixture def mock_account(self, monkeypatch: pytest.MonkeyPatch): """Create a mock Account for testing.""" - account = Account() - account.id = str(uuid.uuid4()) - account.name = "Test User" - account.email = "test@example.com" + + account = Account( + name="Test User", + email="test@example.com", + ) account.last_active_at = naive_utc_now() account.created_at = naive_utc_now() account.updated_at = naive_utc_now() + account.id = str(uuid.uuid4()) - tenant = Tenant() + # Create mock tenant + tenant = Tenant(name="Test Tenant") tenant.id = str(uuid.uuid4()) - tenant.name = "Test Tenant" mock_session_instance = mock.Mock() diff --git a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py index c0fd56ef63..04945e57a0 100644 --- a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py @@ -23,7 +23,7 @@ class TestModelConfigResourcePermissions: """Create a mock App model for testing.""" app = App() app.id = str(uuid.uuid4()) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT app.tenant_id = str(uuid.uuid4()) app.status = "normal" app.app_model_config_id = str(uuid.uuid4()) @@ -32,17 +32,16 @@ class TestModelConfigResourcePermissions: @pytest.fixture def mock_account(self, monkeypatch: pytest.MonkeyPatch): """Create a mock Account for 
testing.""" - account = Account() + + account = Account(name="Test User", email="test@example.com") account.id = str(uuid.uuid4()) - account.name = "Test User" - account.email = "test@example.com" account.last_active_at = naive_utc_now() account.created_at = naive_utc_now() account.updated_at = naive_utc_now() - tenant = Tenant() + # Create mock tenant + tenant = Tenant(name="Test Tenant") tenant.id = str(uuid.uuid4()) - tenant.name = "Test Tenant" mock_session_instance = mock.Mock() diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index aeee882750..f3a5ba0d11 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -542,7 +542,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): index=1, node_execution_id=str(uuid.uuid4()), node_id=self._node_id, - node_type=NodeType.LLM.value, + node_type=NodeType.LLM, title="Test Node", inputs='{"input": "test input"}', process_data='{"test_var": "process_value", "other_var": "other_process"}', diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index 7c1a200c8f..e637530265 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -36,7 +36,7 @@ def test_api_tool(setup_http_mock): entity=ToolEntity( identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")), ), - api_bundle=ApiToolBundle(**tool_bundle), + api_bundle=ApiToolBundle.model_validate(tool_bundle), runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}), provider_id="test_tool", ) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py 
b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 6d2aff5197..8a43d03a43 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -1,5 +1,6 @@ import os from collections import UserDict +from typing import Any from unittest.mock import MagicMock import pytest @@ -9,7 +10,6 @@ from pymochow.model.database import Database # type: ignore from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore from pymochow.model.table import Table # type: ignore -from requests.adapters import HTTPAdapter class AttrDict(UserDict): @@ -21,7 +21,7 @@ class MockBaiduVectorDBClass: def mock_vector_db_client( self, config=None, - adapter: HTTPAdapter | None = None, + adapter: Any | None = None, ): self.conn = MagicMock() self._config = MagicMock() diff --git a/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py b/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py index 9706c52455..9e24672317 100644 --- a/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py @@ -44,25 +44,25 @@ class MockClient: "hits": [ { "_source": { - Field.CONTENT_KEY.value: "abcdef", - Field.VECTOR.value: [1, 2], - Field.METADATA_KEY.value: {}, + Field.CONTENT_KEY: "abcdef", + Field.VECTOR: [1, 2], + Field.METADATA_KEY: {}, }, "_score": 1.0, }, { "_source": { - Field.CONTENT_KEY.value: "123456", - Field.VECTOR.value: [2, 2], - Field.METADATA_KEY.value: {}, + Field.CONTENT_KEY: "123456", + Field.VECTOR: [2, 2], + Field.METADATA_KEY: {}, }, "_score": 0.9, }, { "_source": { - Field.CONTENT_KEY.value: "a1b2c3", - Field.VECTOR.value: [3, 2], - Field.METADATA_KEY.value: {}, + Field.CONTENT_KEY: "a1b2c3", + Field.VECTOR: [3, 2], + Field.METADATA_KEY: {}, }, "_score": 0.8, }, diff --git 
a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index e0b908cece..5130fcfe17 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -1,9 +1,8 @@ import os -from typing import Union +from typing import Any, Union import pytest from _pytest.monkeypatch import MonkeyPatch -from requests.adapters import HTTPAdapter from tcvectordb import RPCVectorDBClient # type: ignore from tcvectordb.model import enum from tcvectordb.model.collection import FilterIndexConfig @@ -23,7 +22,7 @@ class MockTcvectordbClass: key="", read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, timeout=10, - adapter: HTTPAdapter | None = None, + adapter: Any | None = None, pool_size: int = 2, proxies: dict | None = None, password: str | None = None, diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index 3ad72e5550..f351df8d5b 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -40,13 +40,13 @@ class MockVikingDBClass: collection_name=collection_name, description="Collection For Dify", viking_db_service=self._viking_db_service, - primary_key=vdb_Field.PRIMARY_KEY.value, + primary_key=vdb_Field.PRIMARY_KEY, fields=[ - Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True), - Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String), - Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String), - Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text), - Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768), + Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True), + Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String), + 
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String), + Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text), + Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=768), ], indexes=[ Index( @@ -71,7 +71,7 @@ class MockVikingDBClass: return Collection( collection_name=collection_name, description=description, - primary_key=vdb_Field.PRIMARY_KEY.value, + primary_key=vdb_Field.PRIMARY_KEY, viking_db_service=self._viking_db_service, fields=fields, ) @@ -126,11 +126,11 @@ class MockVikingDBClass: def fetch_data(self, id: Union[str, list[str], int, list[int]]): return Data( fields={ - vdb_Field.GROUP_KEY.value: "test_group", - vdb_Field.METADATA_KEY.value: "{}", - vdb_Field.CONTENT_KEY.value: "content", - vdb_Field.PRIMARY_KEY.value: id, - vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], + vdb_Field.GROUP_KEY: "test_group", + vdb_Field.METADATA_KEY: "{}", + vdb_Field.CONTENT_KEY: "content", + vdb_Field.PRIMARY_KEY: id, + vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], }, id=id, ) @@ -151,16 +151,16 @@ class MockVikingDBClass: return [ Data( fields={ - vdb_Field.GROUP_KEY.value: "test_group", - vdb_Field.METADATA_KEY.value: '\ + vdb_Field.GROUP_KEY: "test_group", + vdb_Field.METADATA_KEY: '\ {"source": "/var/folders/ml/xxx/xxx.txt", \ "document_id": "test_document_id", \ "dataset_id": "test_dataset_id", \ "doc_id": "test_id", \ "doc_hash": "test_hash"}', - vdb_Field.CONTENT_KEY.value: "content", - vdb_Field.PRIMARY_KEY.value: "test_id", - vdb_Field.VECTOR.value: vector, + vdb_Field.CONTENT_KEY: "content", + vdb_Field.PRIMARY_KEY: "test_id", + vdb_Field.VECTOR: vector, }, id="test_id", score=0.10, @@ -173,16 +173,16 @@ class MockVikingDBClass: return [ Data( fields={ - vdb_Field.GROUP_KEY.value: "test_group", - vdb_Field.METADATA_KEY.value: '\ + vdb_Field.GROUP_KEY: "test_group", + vdb_Field.METADATA_KEY: '\ {"source": "/var/folders/ml/xxx/xxx.txt", 
\ "document_id": "test_document_id", \ "dataset_id": "test_dataset_id", \ "doc_id": "test_id", \ "doc_hash": "test_hash"}', - vdb_Field.CONTENT_KEY.value: "content", - vdb_Field.PRIMARY_KEY.value: "test_id", - vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], + vdb_Field.CONTENT_KEY: "content", + vdb_Field.PRIMARY_KEY: "test_id", + vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398], }, id="test_id", score=0.10, diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index 2d44dd2924..192c995ce5 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -129,8 +129,8 @@ class TestOpenSearchVector: "hits": [ { "_source": { - Field.CONTENT_KEY.value: get_example_text(), - Field.METADATA_KEY.value: {"document_id": self.example_doc_id}, + Field.CONTENT_KEY: get_example_text(), + Field.METADATA_KEY: {"document_id": self.example_doc_id}, }, "_score": 1.0, } diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index c98406d845..6eff73a8f3 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -16,6 +16,7 @@ from services.errors.account import ( AccountPasswordError, AccountRegisterError, CurrentPasswordIncorrectError, + TenantNotFoundError, ) from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError @@ -63,7 +64,7 @@ class TestAccountService: password=password, ) assert account.email == email - assert account.status == AccountStatus.ACTIVE.value + assert account.status == AccountStatus.ACTIVE # Login with correct password logged_in = 
AccountService.authenticate(email, password) @@ -184,7 +185,7 @@ class TestAccountService: ) # Ban the account - account.status = AccountStatus.BANNED.value + account.status = AccountStatus.BANNED from extensions.ext_database import db db.session.commit() @@ -268,14 +269,14 @@ class TestAccountService: interface_language="en-US", password=password, ) - account.status = AccountStatus.PENDING.value + account.status = AccountStatus.PENDING from extensions.ext_database import db db.session.commit() # Authenticate should activate the account authenticated_account = AccountService.authenticate(email, password) - assert authenticated_account.status == AccountStatus.ACTIVE.value + assert authenticated_account.status == AccountStatus.ACTIVE assert authenticated_account.initialized_at is not None def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -538,7 +539,7 @@ class TestAccountService: from extensions.ext_database import db db.session.refresh(account) - assert account.status == AccountStatus.CLOSED.value + assert account.status == AccountStatus.CLOSED def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -678,7 +679,7 @@ class TestAccountService: interface_language="en-US", password=password, ) - account.status = AccountStatus.PENDING.value + account.status = AccountStatus.PENDING from extensions.ext_database import db db.session.commit() @@ -687,7 +688,7 @@ class TestAccountService: token_pair = AccountService.login(account) db.session.refresh(account) - assert account.status == AccountStatus.ACTIVE.value + assert account.status == AccountStatus.ACTIVE def test_logout(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -859,7 +860,7 @@ class TestAccountService: ) # Ban the account - account.status = AccountStatus.BANNED.value + account.status = AccountStatus.BANNED from extensions.ext_database import db db.session.commit() @@ 
-989,7 +990,7 @@ class TestAccountService: ) # Ban the account - account.status = AccountStatus.BANNED.value + account.status = AccountStatus.BANNED from extensions.ext_database import db db.session.commit() @@ -1414,7 +1415,7 @@ class TestTenantService: ) # Try to get current tenant (should fail) - with pytest.raises(AttributeError): + with pytest.raises((AttributeError, TenantNotFoundError)): TenantService.get_current_tenant_by_account(account) def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 5598c5bc0c..e6bfc157c7 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -86,7 +86,7 @@ class TestFileService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -187,7 +187,7 @@ class TestFileService: assert upload_file.extension == "pdf" assert upload_file.mime_type == mimetype assert upload_file.created_by == account.id - assert upload_file.created_by_role == CreatorUserRole.ACCOUNT.value + assert upload_file.created_by_role == CreatorUserRole.ACCOUNT assert upload_file.used is False assert upload_file.hash == hashlib.sha3_256(content).hexdigest() @@ -216,7 +216,7 @@ class TestFileService: assert upload_file is not None assert upload_file.created_by == end_user.id - assert upload_file.created_by_role == CreatorUserRole.END_USER.value + assert upload_file.created_by_role == CreatorUserRole.END_USER def test_upload_file_with_datasets_source( self, db_session_with_containers, engine, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py 
b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index d0f7e945f1..253791cc2d 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -72,7 +72,7 @@ class TestMetadataService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 66527dd506..8a72331425 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -103,7 +103,7 @@ class TestModelLoadBalancingService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 2196da8b3e..fb319a4963 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -67,7 +67,7 @@ class TestModelProviderService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 04cff397b2..3d1226019b 100644 --- 
a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -66,7 +66,7 @@ class TestTagService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index c9ace46c55..5db7901cbc 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -144,7 +144,7 @@ class TestWebConversationService: system_instruction=fake.text(max_nb_chars=300), system_instruction_tokens=50, status="normal", - invoke_from=InvokeFrom.WEB_APP.value, + invoke_from=InvokeFrom.WEB_APP, from_source="console" if isinstance(user, Account) else "api", from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 316cfe1674..059767458a 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -87,7 +87,7 @@ class TestWebAppAuthService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -150,7 +150,7 @@ class TestWebAppAuthService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, 
current=True, ) db.session.add(join) @@ -232,7 +232,7 @@ class TestWebAppAuthService: assert result.id == account.id assert result.email == account.email assert result.name == account.name - assert result.status == AccountStatus.ACTIVE.value + assert result.status == AccountStatus.ACTIVE # Verify database state from extensions.ext_database import db @@ -280,7 +280,7 @@ class TestWebAppAuthService: email=fake.email(), name=fake.name(), interface_language="en-US", - status=AccountStatus.BANNED.value, + status=AccountStatus.BANNED, ) # Hash password @@ -411,7 +411,7 @@ class TestWebAppAuthService: assert result.id == account.id assert result.email == account.email assert result.name == account.name - assert result.status == AccountStatus.ACTIVE.value + assert result.status == AccountStatus.ACTIVE # Verify database state from extensions.ext_database import db @@ -455,7 +455,7 @@ class TestWebAppAuthService: email=unique_email, name=fake.name(), interface_language="en-US", - status=AccountStatus.BANNED.value, + status=AccountStatus.BANNED, ) from extensions.ext_database import db diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 2e18184aea..62c9bead86 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -199,7 +199,7 @@ class TestWorkflowAppService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), finished_at=datetime.now(UTC), @@ -215,7 +215,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, 
created_by=account.id, created_at=datetime.now(UTC), ) @@ -356,7 +356,7 @@ class TestWorkflowAppService: elapsed_time=1.0 + i, total_tokens=100 + i * 10, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, @@ -371,7 +371,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -464,7 +464,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=timestamp, finished_at=timestamp + timedelta(minutes=1), @@ -479,7 +479,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=timestamp, ) @@ -571,7 +571,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), @@ -586,7 +586,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -701,7 +701,7 @@ class TestWorkflowAppService: elapsed_time=1.0, 
total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), @@ -716,7 +716,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -743,7 +743,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.END_USER.value, + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, created_at=datetime.now(UTC) + timedelta(minutes=i + 10), finished_at=datetime.now(UTC) + timedelta(minutes=i + 11), @@ -758,7 +758,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="web-app", - created_by_role=CreatorUserRole.END_USER.value, + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, created_at=datetime.now(UTC) + timedelta(minutes=i + 10), ) @@ -780,14 +780,14 @@ class TestWorkflowAppService: limit=20, ) assert result_session_filter["total"] == 2 - assert all(log.created_by_role == CreatorUserRole.END_USER.value for log in result_session_filter["data"]) + assert all(log.created_by_role == CreatorUserRole.END_USER for log in result_session_filter["data"]) # Test filtering by account email result_account_filter = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, created_by_account=account.email, page=1, limit=20 ) assert result_account_filter["total"] == 3 - assert all(log.created_by_role == CreatorUserRole.ACCOUNT.value for log in result_account_filter["data"]) + assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_account_filter["data"]) # Test 
filtering by non-existent session ID result_no_session = service.get_paginate_workflow_app_logs( @@ -853,7 +853,7 @@ class TestWorkflowAppService: elapsed_time=1.0, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), finished_at=datetime.now(UTC) + timedelta(minutes=1), @@ -869,7 +869,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), ) @@ -943,7 +943,7 @@ class TestWorkflowAppService: elapsed_time=0.0, # Edge case: 0 elapsed time total_tokens=0, # Edge case: 0 tokens total_steps=0, # Edge case: 0 steps - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), finished_at=datetime.now(UTC), @@ -959,7 +959,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), ) @@ -1098,7 +1098,7 @@ class TestWorkflowAppService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status == "succeeded" else None, @@ -1113,7 +1113,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -1198,7 +1198,7 
@@ class TestWorkflowAppService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, @@ -1213,7 +1213,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i), ) @@ -1300,7 +1300,7 @@ class TestWorkflowAppService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j), finished_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j + 1), @@ -1315,7 +1315,7 @@ class TestWorkflowAppService: workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from="service-api", - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC) + timedelta(minutes=i * 10 + j), ) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 4cb21ef6bd..23c4eeb82f 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -130,7 +130,7 @@ class TestWorkflowRunService: elapsed_time=1.5, total_tokens=100, total_steps=3, - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=created_time, finished_at=created_time, @@ -167,7 +167,7 @@ class 
TestWorkflowRunService: inputs={}, status="normal", mode="chat", - from_source=CreatorUserRole.ACCOUNT.value, + from_source=CreatorUserRole.ACCOUNT, from_account_id=account.id, ) db.session.add(conversation) @@ -188,7 +188,7 @@ class TestWorkflowRunService: message.answer_price_unit = 0.001 message.currency = "USD" message.status = "normal" - message.from_source = CreatorUserRole.ACCOUNT.value + message.from_source = CreatorUserRole.ACCOUNT message.from_account_id = account.id message.workflow_run_id = workflow_run.id message.inputs = {"input": "test input"} @@ -458,7 +458,7 @@ class TestWorkflowRunService: status="succeeded", elapsed_time=0.5, execution_metadata=json.dumps({"tokens": 50}), - created_by_role=CreatorUserRole.ACCOUNT.value, + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, created_at=datetime.now(UTC), ) @@ -689,7 +689,7 @@ class TestWorkflowRunService: status="succeeded", elapsed_time=0.5, execution_metadata=json.dumps({"tokens": 50}), - created_by_role=CreatorUserRole.END_USER.value, + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, created_at=datetime.now(UTC), ) @@ -710,4 +710,4 @@ class TestWorkflowRunService: assert node_exec.app_id == app.id assert node_exec.workflow_run_id == workflow_run.id assert node_exec.created_by == end_user.id - assert node_exec.created_by_role == CreatorUserRole.END_USER.value + assert node_exec.created_by_role == CreatorUserRole.END_USER diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 60150667ed..4741eba1f5 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -44,27 +44,26 @@ class TestWorkflowService: Account: Created test account instance """ fake = fake or Faker() - account = Account() - account.id = fake.uuid4() - account.email = 
fake.email() - account.name = fake.name() - account.avatar_url = fake.url() - account.tenant_id = fake.uuid4() - account.status = "active" - account.type = "normal" - account.role = "owner" - account.interface_language = "en-US" # Set interface language for Site creation + account = Account( + email=fake.email(), + name=fake.name(), + avatar=fake.url(), + status="active", + interface_language="en-US", # Set interface language for Site creation + ) account.created_at = fake.date_time_this_year() + account.id = fake.uuid4() account.updated_at = account.created_at # Create a tenant for the account from models.account import Tenant - tenant = Tenant() - tenant.id = account.tenant_id - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" + tenant = Tenant( + name=f"Test Tenant {fake.company()}", + plan="basic", + status="active", + ) + tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -91,20 +90,21 @@ class TestWorkflowService: App: Created test app instance """ fake = fake or Faker() - app = App() - app.id = fake.uuid4() - app.tenant_id = fake.uuid4() - app.name = fake.company() - app.description = fake.text() - app.mode = AppMode.WORKFLOW - app.icon_type = "emoji" - app.icon = "🤖" - app.icon_background = "#FFEAD5" - app.enable_site = True - app.enable_api = True - app.created_by = fake.uuid4() + app = App( + id=fake.uuid4(), + tenant_id=fake.uuid4(), + name=fake.company(), + description=fake.text(), + mode=AppMode.WORKFLOW, + icon_type="emoji", + icon="🤖", + icon_background="#FFEAD5", + enable_site=True, + enable_api=True, + created_by=fake.uuid4(), + workflow_id=None, # Will be set when workflow is created + ) app.updated_by = app.created_by - app.workflow_id = None # Will be set when workflow is created from extensions.ext_database import db @@ -126,19 +126,20 @@ class TestWorkflowService: Workflow: Created test workflow instance """ fake = fake or 
Faker() - workflow = Workflow() - workflow.id = fake.uuid4() - workflow.tenant_id = app.tenant_id - workflow.app_id = app.id - workflow.type = WorkflowType.WORKFLOW.value - workflow.version = Workflow.VERSION_DRAFT - workflow.graph = json.dumps({"nodes": [], "edges": []}) - workflow.features = json.dumps({"features": []}) - # unique_hash is a computed property based on graph and features - workflow.created_by = account.id - workflow.updated_by = account.id - workflow.environment_variables = [] - workflow.conversation_variables = [] + workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({"features": []}), + # unique_hash is a computed property based on graph and features + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) from extensions.ext_database import db @@ -175,7 +176,7 @@ class TestWorkflowService: node_execution.node_type = "test_node" node_execution.title = "Test Node" # Required field node_execution.status = "succeeded" - node_execution.created_by_role = CreatorUserRole.ACCOUNT.value # Required field + node_execution.created_by_role = CreatorUserRole.ACCOUNT # Required field node_execution.created_by = account.id # Required field node_execution.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 3fd439256d..814d1908bd 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -69,7 +69,7 @@ class TestWorkspaceService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + 
role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -111,7 +111,7 @@ class TestWorkspaceService: assert result["name"] == tenant.name assert result["plan"] == tenant.plan assert result["status"] == tenant.status - assert result["role"] == TenantAccountRole.OWNER.value + assert result["role"] == TenantAccountRole.OWNER assert result["created_at"] == tenant.created_at assert result["trial_end_reason"] is None @@ -159,7 +159,7 @@ class TestWorkspaceService: assert result["name"] == tenant.name assert result["plan"] == tenant.plan assert result["status"] == tenant.status - assert result["role"] == TenantAccountRole.OWNER.value + assert result["role"] == TenantAccountRole.OWNER assert result["created_at"] == tenant.created_at assert result["trial_end_reason"] is None @@ -194,7 +194,7 @@ class TestWorkspaceService: from extensions.ext_database import db join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - join.role = TenantAccountRole.NORMAL.value + join.role = TenantAccountRole.NORMAL db.session.commit() # Setup mocks for feature service @@ -212,7 +212,7 @@ class TestWorkspaceService: assert result["name"] == tenant.name assert result["plan"] == tenant.plan assert result["status"] == tenant.status - assert result["role"] == TenantAccountRole.NORMAL.value + assert result["role"] == TenantAccountRole.NORMAL assert result["created_at"] == tenant.created_at assert result["trial_end_reason"] is None @@ -245,7 +245,7 @@ class TestWorkspaceService: from extensions.ext_database import db join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - join.role = TenantAccountRole.ADMIN.value + join.role = TenantAccountRole.ADMIN db.session.commit() # Setup mocks for feature service and tenant service @@ -260,7 +260,7 @@ class TestWorkspaceService: # Assert: Verify the expected outcomes assert result is not None - assert result["role"] == 
TenantAccountRole.ADMIN.value + assert result["role"] == TenantAccountRole.ADMIN # Verify custom config is included for admin users assert "custom_config" in result @@ -378,7 +378,7 @@ class TestWorkspaceService: from extensions.ext_database import db join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - join.role = TenantAccountRole.EDITOR.value + join.role = TenantAccountRole.EDITOR db.session.commit() # Setup mocks for feature service and tenant service @@ -394,7 +394,7 @@ class TestWorkspaceService: # Assert: Verify the expected outcomes assert result is not None - assert result["role"] == TenantAccountRole.EDITOR.value + assert result["role"] == TenantAccountRole.EDITOR # Verify custom config is not included for editor users without admin privileges assert "custom_config" not in result @@ -425,7 +425,7 @@ class TestWorkspaceService: from extensions.ext_database import db join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() - join.role = TenantAccountRole.DATASET_OPERATOR.value + join.role = TenantAccountRole.DATASET_OPERATOR db.session.commit() # Setup mocks for feature service and tenant service @@ -441,7 +441,7 @@ class TestWorkspaceService: # Assert: Verify the expected outcomes assert result is not None - assert result["role"] == TenantAccountRole.DATASET_OPERATOR.value + assert result["role"] == TenantAccountRole.DATASET_OPERATOR # Verify custom config is not included for dataset operators without admin privileges assert "custom_config" not in result diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index a412bdccf8..7366b08439 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ 
b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -72,7 +72,7 @@ class TestApiToolManageService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index dd22dcbfd1..f7a4c53318 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -72,7 +72,7 @@ class TestMCPToolManageService: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 827f9c010e..ae0c7b7a6b 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -168,7 +168,7 @@ class TestToolTransformService: """ # Arrange: Setup test data fake = Faker() - provider_type = ToolProviderType.BUILT_IN.value + provider_type = ToolProviderType.BUILT_IN provider_name = fake.company() icon = "🔧" @@ -206,7 +206,7 @@ class TestToolTransformService: """ # Arrange: Setup test data fake = Faker() - provider_type = ToolProviderType.API.value + provider_type = ToolProviderType.API provider_name = fake.company() icon = '{"background": "#FF6B6B", "content": "🔧"}' @@ -231,7 +231,7 @@ class TestToolTransformService: """ # Arrange: Setup test data with invalid 
JSON fake = Faker() - provider_type = ToolProviderType.API.value + provider_type = ToolProviderType.API provider_name = fake.company() icon = '{"invalid": json}' @@ -257,7 +257,7 @@ class TestToolTransformService: """ # Arrange: Setup test data fake = Faker() - provider_type = ToolProviderType.WORKFLOW.value + provider_type = ToolProviderType.WORKFLOW provider_name = fake.company() icon = {"background": "#FF6B6B", "content": "🔧"} @@ -282,7 +282,7 @@ class TestToolTransformService: """ # Arrange: Setup test data fake = Faker() - provider_type = ToolProviderType.MCP.value + provider_type = ToolProviderType.MCP provider_name = fake.company() icon = {"background": "#FF6B6B", "content": "🔧"} @@ -329,7 +329,7 @@ class TestToolTransformService: # Arrange: Setup test data fake = Faker() tenant_id = fake.uuid4() - provider = {"type": ToolProviderType.BUILT_IN.value, "name": fake.company(), "icon": "🔧"} + provider = {"type": ToolProviderType.BUILT_IN, "name": fake.company(), "icon": "🔧"} # Act: Execute the method under test ToolTransformService.repack_provider(tenant_id, provider) diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 18ab4bb73c..88aa0b6e72 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -66,7 +66,7 @@ class TestWorkflowConverter: mock_config.model = ModelConfigEntity( provider="openai", model="gpt-4", - mode=LLMMode.CHAT.value, + mode=LLMMode.CHAT, parameters={}, stop=[], ) @@ -120,7 +120,7 @@ class TestWorkflowConverter: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -150,7 +150,7 @@ class TestWorkflowConverter: app = App( 
tenant_id=tenant.id, name=fake.company(), - mode=AppMode.CHAT.value, + mode=AppMode.CHAT, icon_type="emoji", icon="🤖", icon_background="#FF6B6B", @@ -218,7 +218,7 @@ class TestWorkflowConverter: # Assert: Verify the expected outcomes assert new_app is not None assert new_app.name == "Test Workflow App" - assert new_app.mode == AppMode.ADVANCED_CHAT.value + assert new_app.mode == AppMode.ADVANCED_CHAT assert new_app.icon_type == "emoji" assert new_app.icon == "🚀" assert new_app.icon_background == "#4CAF50" @@ -257,7 +257,7 @@ class TestWorkflowConverter: app = App( tenant_id=tenant.id, name=fake.company(), - mode=AppMode.CHAT.value, + mode=AppMode.CHAT, icon_type="emoji", icon="🤖", icon_background="#FF6B6B", @@ -522,7 +522,7 @@ class TestWorkflowConverter: model_config = ModelConfigEntity( provider="openai", model="gpt-4", - mode=LLMMode.CHAT.value, + mode=LLMMode.CHAT, parameters={"temperature": 0.7}, stop=[], ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 4600f2addb..96e673d855 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -63,7 +63,7 @@ class TestAddDocumentToIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 3d17a8ac9d..8628e2af7f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -84,7 +84,7 @@ class 
TestBatchCleanDocumentTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index fcae93c669..a9cfb6ffd4 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -112,7 +112,7 @@ class TestBatchCreateSegmentToIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index de81295100..987ebf8aca 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -91,7 +91,7 @@ class TestCreateSegmentToIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 7af4f238be..94e9b76965 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -48,11 +48,8 @@ class TestDeleteSegmentFromIndexTask: 
Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant() + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active") tenant.id = fake.uuid4() - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -73,16 +70,14 @@ class TestDeleteSegmentFromIndexTask: Account: Created test account instance """ fake = fake or Faker() - account = Account() + account = Account( + name=fake.name(), + email=fake.email(), + avatar=fake.url(), + status="active", + interface_language="en-US", + ) account.id = fake.uuid4() - account.email = fake.email() - account.name = fake.name() - account.avatar_url = fake.url() - account.tenant_id = tenant.id - account.status = "active" - account.type = "normal" - account.role = "owner" - account.interface_language = "en-US" account.created_at = fake.date_time_this_year() account.updated_at = account.created_at diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index e1d63e993b..bc3701d098 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -69,7 +69,7 @@ class TestDisableSegmentFromIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 5fdb8c617c..0b36e0914a 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -43,27 +43,30 @@ class TestDisableSegmentsFromIndexTask: Account: Created test account instance """ fake = fake or Faker() - account = Account() + account = Account( + email=fake.email(), + name=fake.name(), + avatar=fake.url(), + status="active", + interface_language="en-US", + ) account.id = fake.uuid4() - account.email = fake.email() - account.name = fake.name() - account.avatar_url = fake.url() + # monkey-patch attributes for test setup account.tenant_id = fake.uuid4() - account.status = "active" account.type = "normal" account.role = "owner" - account.interface_language = "en-US" account.created_at = fake.date_time_this_year() account.updated_at = account.created_at # Create a tenant for the account from models.account import Tenant - tenant = Tenant() + tenant = Tenant( + name=f"Test Tenant {fake.company()}", + plan="basic", + status="active", + ) tenant.id = account.tenant_id - tenant.name = f"Test Tenant {fake.company()}" - tenant.plan = "basic" - tenant.status = "active" tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -91,20 +94,21 @@ class TestDisableSegmentsFromIndexTask: Dataset: Created test dataset instance """ fake = fake or Faker() - dataset = Dataset() - dataset.id = fake.uuid4() - dataset.tenant_id = account.tenant_id - dataset.name = f"Test Dataset {fake.word()}" - dataset.description = fake.text(max_nb_chars=200) - dataset.provider = "vendor" - dataset.permission = "only_me" - dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" - dataset.created_by = account.id - dataset.updated_by = account.id - dataset.embedding_model = "text-embedding-ada-002" - dataset.embedding_model_provider = "openai" - dataset.built_in_field_enabled = False + dataset = Dataset( + id=fake.uuid4(), + 
tenant_id=account.tenant_id, + name=f"Test Dataset {fake.word()}", + description=fake.text(max_nb_chars=200), + provider="vendor", + permission="only_me", + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + updated_by=account.id, + embedding_model="text-embedding-ada-002", + embedding_model_provider="openai", + built_in_field_enabled=False, + ) from extensions.ext_database import db @@ -128,6 +132,7 @@ class TestDisableSegmentsFromIndexTask: """ fake = fake or Faker() document = DatasetDocument() + document.id = fake.uuid4() document.tenant_id = dataset.tenant_id document.dataset_id = dataset.id @@ -153,7 +158,6 @@ class TestDisableSegmentsFromIndexTask: document.archived = False document.doc_form = "text_model" # Use text_model form for testing document.doc_language = "en" - from extensions.ext_database import db db.session.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index f75dcf06e1..a315577b78 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -72,7 +72,7 @@ class TestDocumentIndexingTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) @@ -154,7 +154,7 @@ class TestDocumentIndexingTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 38056496e7..798fe091ab 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -63,7 +63,7 @@ class TestEnableSegmentsToIndexTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py index 2f38246787..31e9b67421 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py @@ -66,7 +66,7 @@ class TestMailAccountDeletionTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db.session.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py index 9cf348d989..1aed7dc7cc 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py @@ -65,7 +65,7 @@ class TestMailChangeMailTask: join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, + role=TenantAccountRole.OWNER, current=True, ) db_session_with_containers.add(join) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index 8fef87b317..c083861004 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ 
b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -95,10 +95,10 @@ class TestMailInviteMemberTask: name=fake.name(), password=fake.password(), interface_language="en-US", - status=AccountStatus.ACTIVE.value, - created_at=datetime.now(UTC), - updated_at=datetime.now(UTC), + status=AccountStatus.ACTIVE, ) + account.created_at = datetime.now(UTC) + account.updated_at = datetime.now(UTC) db_session_with_containers.add(account) db_session_with_containers.commit() db_session_with_containers.refresh(account) @@ -106,9 +106,9 @@ class TestMailInviteMemberTask: # Create tenant tenant = Tenant( name=fake.company(), - created_at=datetime.now(UTC), - updated_at=datetime.now(UTC), ) + tenant.created_at = datetime.now(UTC) + tenant.updated_at = datetime.now(UTC) db_session_with_containers.add(tenant) db_session_with_containers.commit() db_session_with_containers.refresh(tenant) @@ -117,9 +117,9 @@ class TestMailInviteMemberTask: tenant_join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.OWNER.value, - created_at=datetime.now(UTC), + role=TenantAccountRole.OWNER, ) + tenant_join.created_at = datetime.now(UTC) db_session_with_containers.add(tenant_join) db_session_with_containers.commit() @@ -163,10 +163,11 @@ class TestMailInviteMemberTask: name=email.split("@")[0], password="", interface_language="en-US", - status=AccountStatus.PENDING.value, - created_at=datetime.now(UTC), - updated_at=datetime.now(UTC), + status=AccountStatus.PENDING, ) + + account.created_at = datetime.now(UTC) + account.updated_at = datetime.now(UTC) db_session_with_containers.add(account) db_session_with_containers.commit() db_session_with_containers.refresh(account) @@ -175,9 +176,9 @@ class TestMailInviteMemberTask: tenant_join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, - role=TenantAccountRole.NORMAL.value, - created_at=datetime.now(UTC), + role=TenantAccountRole.NORMAL, ) + tenant_join.created_at = 
datetime.now(UTC) db_session_with_containers.add(tenant_join) db_session_with_containers.commit() @@ -485,7 +486,7 @@ class TestMailInviteMemberTask: db_session_with_containers.refresh(pending_account) db_session_with_containers.refresh(tenant) - assert pending_account.status == AccountStatus.PENDING.value + assert pending_account.status == AccountStatus.PENDING assert pending_account.email == invitee_email assert tenant.name is not None @@ -496,7 +497,7 @@ class TestMailInviteMemberTask: .first() ) assert tenant_join is not None - assert tenant_join.role == TenantAccountRole.NORMAL.value + assert tenant_join.role == TenantAccountRole.NORMAL def test_send_invite_member_mail_token_lifecycle_management( self, db_session_with_containers, mock_external_service_dependencies diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 1a2e27e8fe..67f4b85413 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -143,7 +143,7 @@ class TestOAuthCallback: oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com") account = MagicMock() - account.status = AccountStatus.ACTIVE.value + account.status = AccountStatus.ACTIVE token_pair = MagicMock() token_pair.access_token = "jwt_access_token" @@ -220,11 +220,11 @@ class TestOAuthCallback: @pytest.mark.parametrize( ("account_status", "expected_redirect"), [ - (AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."), + (AccountStatus.BANNED, "http://localhost:3000/signin?message=Account is banned."), # CLOSED status: Currently NOT handled, will proceed to login (security issue) # This documents actual behavior. 
See test_defensive_check_for_closed_account_status for details ( - AccountStatus.CLOSED.value, + AccountStatus.CLOSED, "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token", ), ], @@ -296,13 +296,13 @@ class TestOAuthCallback: mock_get_providers.return_value = {"github": oauth_setup["provider"]} mock_account = MagicMock() - mock_account.status = AccountStatus.PENDING.value + mock_account.status = AccountStatus.PENDING mock_generate_account.return_value = mock_account with app.test_request_context("/auth/oauth/github/callback?code=test_code"): resource.get("github") - assert mock_account.status == AccountStatus.ACTIVE.value + assert mock_account.status == AccountStatus.ACTIVE assert mock_account.initialized_at is not None mock_db.session.commit.assert_called_once() @@ -352,7 +352,7 @@ class TestOAuthCallback: # Create account with CLOSED status closed_account = MagicMock() - closed_account.status = AccountStatus.CLOSED.value + closed_account.status = AccountStatus.CLOSED closed_account.id = "123" closed_account.name = "Closed Account" mock_generate_account.return_value = closed_account diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index 48cc8a7e1c..fb2ddfe162 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -11,8 +11,8 @@ def test_default_value(): config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: - MilvusConfig(**config) + MilvusConfig.model_validate(config) assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" - config = MilvusConfig(**valid_config) + config = MilvusConfig.model_validate(valid_config) assert config.database == "default" diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py 
b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 6689e13b96..e5ead6ff66 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -18,7 +18,7 @@ def test_firecrawl_web_extractor_crawl_mode(mocker): mocked_firecrawl = { "id": "test", } - mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl)) + mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl)) job_id = firecrawl_app.crawl_url(url, params) assert job_id is not None diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index eea584a2f8..f1e1820acc 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -69,7 +69,7 @@ def test_notion_page(mocker): ], "next_cursor": None, } - mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) + mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page)) page_docs = extractor._load_data_as_documents(page_id, "page") assert len(page_docs) == 1 @@ -84,7 +84,7 @@ def test_notion_database(mocker): "results": [_generate_page(i) for i in page_title_list], "next_cursor": None, } - mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) + mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database)) database_docs = extractor._load_data_as_documents(database_id, "database") assert len(database_docs) == 1 content = _remove_multiple_new_lines(database_docs[0].page_content) diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e7733b2317..e6d0371cd5 100644 --- 
a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -140,7 +140,7 @@ class TestCeleryWorkflowExecutionRepository: assert call_args["execution_data"] == sample_workflow_execution.model_dump() assert call_args["tenant_id"] == mock_account.current_tenant_id assert call_args["app_id"] == "test-app" - assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value + assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN assert call_args["creator_user_id"] == mock_account.id # Verify no task tracking occurs (no _pending_saves attribute) diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 3abe20fca1..f6211f4cca 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -149,7 +149,7 @@ class TestCeleryWorkflowNodeExecutionRepository: assert call_args["execution_data"] == sample_workflow_node_execution.model_dump() assert call_args["tenant_id"] == mock_account.current_tenant_id assert call_args["app_id"] == "test-app" - assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN assert call_args["creator_user_id"] == mock_account.id # Verify execution is cached diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 36f7d3ef55..485be90eae 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ 
b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -145,12 +145,12 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation: db_model.index = 1 db_model.predecessor_node_id = None db_model.node_id = "node-id" - db_model.node_type = NodeType.LLM.value + db_model.node_type = NodeType.LLM db_model.title = "Test Node" db_model.inputs = json.dumps({"value": "inputs"}) db_model.process_data = json.dumps({"value": "process_data"}) db_model.outputs = json.dumps({"value": "outputs"}) - db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED db_model.error = None db_model.elapsed_time = 1.0 db_model.execution_metadata = "{}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 2c08fff27b..7ebccf83a7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -147,7 +147,7 @@ class TestRedisChannel: """Test deserializing an abort command.""" channel = RedisChannel(MagicMock(), "test:key") - abort_data = {"command_type": CommandType.ABORT.value} + abort_data = {"command_type": CommandType.ABORT} command = channel._deserialize_command(abort_data) assert isinstance(command, AbortCommand) @@ -158,7 +158,7 @@ class TestRedisChannel: channel = RedisChannel(MagicMock(), "test:key") # For now, only ABORT is supported, but test generic handling - generic_data = {"command_type": CommandType.ABORT.value} + generic_data = {"command_type": CommandType.ABORT} command = channel._deserialize_command(generic_data) assert command is not None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 
6a9bfbdcc3..c39c12925f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -56,8 +56,8 @@ def test_mock_iteration_node_preserves_config(): workflow_id="test", graph_config={"nodes": [], "edges": []}, user_id="test", - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) @@ -117,8 +117,8 @@ def test_mock_loop_node_preserves_config(): workflow_id="test", graph_config={"nodes": [], "edges": []}, user_id="test", - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index b286d99f70..bd41fdeee5 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -49,7 +49,7 @@ class TestRedisStopIntegration: # Verify the command data command_json = calls[0][0][1] command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.ABORT.value + assert command_data["command_type"] == CommandType.ABORT assert command_data["reason"] == "Test stop" def test_graph_engine_manager_handles_redis_failure_gracefully(self): @@ -122,7 +122,7 @@ class TestRedisStopIntegration: # Verify serialized command command_json = calls[0][0][1] command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.ABORT.value + assert command_data["command_type"] == CommandType.ABORT assert command_data["reason"] == "User requested stop" # Check expire was set @@ -137,9 +137,7 @@ class TestRedisStopIntegration: 
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) # Mock command data - abort_command_json = json.dumps( - {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None} - ) + abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None}) # Mock pipeline execute to return commands mock_pipeline.execute.return_value = [ diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index b942614232..55fe62ca43 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -35,7 +35,7 @@ def list_operator_node(): "extract_by": ExtractConfig(enabled=False, serial="1"), "title": "Test Title", } - node_data = ListOperatorNodeData(**config) + node_data = ListOperatorNodeData.model_validate(config) node_config = { "id": "test_node_id", "data": node_data.model_dump(), diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py index f990280c5f..47ef289ef3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -17,7 +17,7 @@ def test_init_question_classifier_node_data(): "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, } - node_data = QuestionClassifierNodeData(**data) + node_data = QuestionClassifierNodeData.model_validate(data) assert node_data.query_variable_selector == ["id", "name"] assert node_data.model.provider == "openai" @@ -49,7 +49,7 @@ def test_init_question_classifier_node_data_without_vision_config(): }, } - node_data = QuestionClassifierNodeData(**data) + node_data = QuestionClassifierNodeData.model_validate(data) assert node_data.query_variable_selector == 
["id", "name"] assert node_data.model.provider == "openai" diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 3e50d5522a..6189febdf5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -87,7 +87,7 @@ def test_overwrite_string_variable(): "data": { "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE.value, + "write_mode": WriteMode.OVER_WRITE, "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, } @@ -189,7 +189,7 @@ def test_append_variable_to_array(): "data": { "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND.value, + "write_mode": WriteMode.APPEND, "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, } @@ -282,7 +282,7 @@ def test_clear_array(): "data": { "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR.value, + "write_mode": WriteMode.CLEAR, "input_variable_selector": [], }, } diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 11d788ed79..3ae5edb383 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -46,7 +46,7 @@ class TestSystemVariableSerialization: def test_basic_deserialization(self): """Test successful deserialization from JSON structure with all fields correctly mapped.""" # Test with complete data - system_var = SystemVariable(**COMPLETE_VALID_DATA) + system_var = SystemVariable.model_validate(COMPLETE_VALID_DATA) # Verify 
all fields are correctly mapped assert system_var.user_id == COMPLETE_VALID_DATA["user_id"] @@ -59,7 +59,7 @@ class TestSystemVariableSerialization: assert system_var.files == [] # Test with minimal data (only required fields) - minimal_var = SystemVariable(**VALID_BASE_DATA) + minimal_var = SystemVariable.model_validate(VALID_BASE_DATA) assert minimal_var.user_id == VALID_BASE_DATA["user_id"] assert minimal_var.app_id == VALID_BASE_DATA["app_id"] assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"] @@ -75,12 +75,12 @@ class TestSystemVariableSerialization: # Test workflow_run_id only (preferred alias) data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} - system_var1 = SystemVariable(**data_run_id) + system_var1 = SystemVariable.model_validate(data_run_id) assert system_var1.workflow_execution_id == workflow_id # Test workflow_execution_id only (direct field name) data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} - system_var2 = SystemVariable(**data_execution_id) + system_var2 = SystemVariable.model_validate(data_execution_id) assert system_var2.workflow_execution_id == workflow_id # Test both present - workflow_run_id should take precedence @@ -89,17 +89,17 @@ class TestSystemVariableSerialization: "workflow_execution_id": "should-be-ignored", "workflow_run_id": workflow_id, } - system_var3 = SystemVariable(**data_both) + system_var3 = SystemVariable.model_validate(data_both) assert system_var3.workflow_execution_id == workflow_id # Test neither present - should be None - system_var4 = SystemVariable(**VALID_BASE_DATA) + system_var4 = SystemVariable.model_validate(VALID_BASE_DATA) assert system_var4.workflow_execution_id is None def test_serialization_round_trip(self): """Test that serialize → deserialize produces the same result with alias handling.""" # Create original SystemVariable - original = SystemVariable(**COMPLETE_VALID_DATA) + original = SystemVariable.model_validate(COMPLETE_VALID_DATA) # Serialize 
to dict serialized = original.model_dump(mode="json") @@ -110,7 +110,7 @@ class TestSystemVariableSerialization: assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] # Deserialize back - deserialized = SystemVariable(**serialized) + deserialized = SystemVariable.model_validate(serialized) # Verify all fields match after round-trip assert deserialized.user_id == original.user_id @@ -125,7 +125,7 @@ class TestSystemVariableSerialization: def test_json_round_trip(self): """Test JSON serialization/deserialization consistency with proper structure.""" # Create original SystemVariable - original = SystemVariable(**COMPLETE_VALID_DATA) + original = SystemVariable.model_validate(COMPLETE_VALID_DATA) # Serialize to JSON string json_str = original.model_dump_json() @@ -137,7 +137,7 @@ class TestSystemVariableSerialization: assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] # Deserialize from JSON data - deserialized = SystemVariable(**json_data) + deserialized = SystemVariable.model_validate(json_data) # Verify key fields match after JSON round-trip assert deserialized.workflow_execution_id == original.workflow_execution_id @@ -149,13 +149,13 @@ class TestSystemVariableSerialization: """Test deserialization with File objects in the files field - SystemVariable specific logic.""" # Test with empty files list data_empty = {**VALID_BASE_DATA, "files": []} - system_var_empty = SystemVariable(**data_empty) + system_var_empty = SystemVariable.model_validate(data_empty) assert system_var_empty.files == [] # Test with single File object test_file = create_test_file() data_single = {**VALID_BASE_DATA, "files": [test_file]} - system_var_single = SystemVariable(**data_single) + system_var_single = SystemVariable.model_validate(data_single) assert len(system_var_single.files) == 1 assert system_var_single.files[0].filename == "test.txt" assert system_var_single.files[0].tenant_id == "test-tenant-id" @@ -179,14 +179,14 @@ class 
TestSystemVariableSerialization: ) data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]} - system_var_multiple = SystemVariable(**data_multiple) + system_var_multiple = SystemVariable.model_validate(data_multiple) assert len(system_var_multiple.files) == 2 assert system_var_multiple.files[0].filename == "doc1.txt" assert system_var_multiple.files[1].filename == "image.jpg" # Verify files field serialization/deserialization serialized = system_var_multiple.model_dump(mode="json") - deserialized = SystemVariable(**serialized) + deserialized = SystemVariable.model_validate(serialized) assert len(deserialized.files) == 2 assert deserialized.files[0].filename == "doc1.txt" assert deserialized.files[1].filename == "image.jpg" @@ -197,7 +197,7 @@ class TestSystemVariableSerialization: # Create with workflow_run_id (alias) data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} - system_var = SystemVariable(**data_with_alias) + system_var = SystemVariable.model_validate(data_with_alias) # Serialize and verify alias is used serialized = system_var.model_dump() @@ -205,7 +205,7 @@ class TestSystemVariableSerialization: assert "workflow_execution_id" not in serialized # Deserialize and verify field mapping - deserialized = SystemVariable(**serialized) + deserialized = SystemVariable.model_validate(serialized) assert deserialized.workflow_execution_id == workflow_id # Test JSON serialization path @@ -213,7 +213,7 @@ class TestSystemVariableSerialization: assert json_serialized["workflow_run_id"] == workflow_id assert "workflow_execution_id" not in json_serialized - json_deserialized = SystemVariable(**json_serialized) + json_deserialized = SystemVariable.model_validate(json_serialized) assert json_deserialized.workflow_execution_id == workflow_id def test_model_validator_serialization_logic(self): @@ -222,7 +222,7 @@ class TestSystemVariableSerialization: # Test direct instantiation with workflow_execution_id (should work) data1 = {**VALID_BASE_DATA, 
"workflow_execution_id": workflow_id} - system_var1 = SystemVariable(**data1) + system_var1 = SystemVariable.model_validate(data1) assert system_var1.workflow_execution_id == workflow_id # Test serialization of the above (should use alias) @@ -236,7 +236,7 @@ class TestSystemVariableSerialization: "workflow_execution_id": "should-be-removed", "workflow_run_id": workflow_id, } - system_var2 = SystemVariable(**data2) + system_var2 = SystemVariable.model_validate(data2) assert system_var2.workflow_execution_id == workflow_id # Verify serialization consistency diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index b7701055f5..85789bfa7e 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -11,7 +11,7 @@ class TestExtractTenantId: def test_extract_tenant_id_from_account_with_tenant(self): """Test extracting tenant_id from Account with current_tenant_id.""" # Create a mock Account object - account = Account() + account = Account(name="test", email="test@example.com") # Mock the current_tenant_id property account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})() @@ -21,7 +21,7 @@ class TestExtractTenantId: def test_extract_tenant_id_from_account_without_tenant(self): """Test extracting tenant_id from Account without current_tenant_id.""" # Create a mock Account object - account = Account() + account = Account(name="test", email="test@example.com") account._current_tenant = None tenant_id = extract_tenant_id(account) diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index fadd1ee88f..5cba43714a 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -59,12 +59,11 @@ def 
session(): @pytest.fixture def mock_user(): """Create a user instance for testing.""" - user = Account() + user = Account(name="test", email="test@example.com") user.id = "test-user-id" - tenant = Tenant() + tenant = Tenant(name="Test Workspace") tenant.id = "test-tenant" - tenant.name = "Test Workspace" user._current_tenant = MagicMock() user._current_tenant.id = "test-tenant" @@ -299,7 +298,7 @@ def test_to_domain_model(repository): db_model.predecessor_node_id = "test-predecessor-id" db_model.node_execution_id = "test-node-execution-id" db_model.node_id = "test-node-id" - db_model.node_type = NodeType.START.value + db_model.node_type = NodeType.START db_model.title = "Test Node" db_model.inputs = json.dumps(inputs_dict) db_model.process_data = json.dumps(process_data_dict) diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index 0ff1edc950..31fe9b2868 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -118,7 +118,7 @@ class TestMetadataBugCompleteValidation: # But would crash when trying to create MetadataArgs with pytest.raises((ValueError, TypeError)): - MetadataArgs(**args) + MetadataArgs.model_validate(args) def test_7_end_to_end_validation_layers(self): """Test all validation layers work together correctly.""" @@ -131,7 +131,7 @@ class TestMetadataBugCompleteValidation: valid_data = {"type": "string", "name": "test_metadata"} # Should create valid Pydantic object - metadata_args = MetadataArgs(**valid_data) + metadata_args = MetadataArgs.model_validate(valid_data) assert metadata_args.type == "string" assert metadata_args.name == "test_metadata" diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index d151100cf3..c8cd7025c2 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ 
b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -76,7 +76,7 @@ class TestMetadataNullableBug: # Step 2: Try to create MetadataArgs with None values # This should fail at Pydantic validation level with pytest.raises((ValueError, TypeError)): - metadata_args = MetadataArgs(**args) + metadata_args = MetadataArgs.model_validate(args) # Step 3: If we bypass Pydantic (simulating the bug scenario) # Move this outside the request context to avoid Flask-Login issues diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 2ca781bae5..63ce4c0c3c 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -107,7 +107,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): assert body_data body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY body_params = body_data_json["params"] assert body_params["app_id"] == app_model.id @@ -168,7 +168,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): assert body_data body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY body_params = body_data_json["params"] assert body_params["app_id"] == app_model.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 7e324ca4db..66361f26e0 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ 
b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -47,7 +47,8 @@ class TestDraftVariableSaver: def test__should_variable_be_visible(self): mock_session = MagicMock(spec=Session) - mock_user = Account(id=str(uuid.uuid4())) + mock_user = Account(name="test", email="test@example.com") + mock_user.id = str(uuid.uuid4()) test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session, diff --git a/api/uv.lock b/api/uv.lock index 21d1f17bad..af368199b7 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1325,7 +1325,6 @@ dependencies = [ { name = "opentelemetry-instrumentation-flask" }, { name = "opentelemetry-instrumentation-httpx" }, { name = "opentelemetry-instrumentation-redis" }, - { name = "opentelemetry-instrumentation-requests" }, { name = "opentelemetry-instrumentation-sqlalchemy" }, { name = "opentelemetry-propagator-b3" }, { name = "opentelemetry-proto" }, @@ -1418,8 +1417,6 @@ dev = [ { name = "types-pyyaml" }, { name = "types-redis" }, { name = "types-regex" }, - { name = "types-requests" }, - { name = "types-requests-oauthlib" }, { name = "types-setuptools" }, { name = "types-shapely" }, { name = "types-simplejson" }, @@ -1516,7 +1513,6 @@ requires-dist = [ { name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-httpx", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" }, - { name = "opentelemetry-instrumentation-requests", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" }, { name = "opentelemetry-propagator-b3", specifier = "==1.27.0" }, { name = "opentelemetry-proto", specifier = "==1.27.0" }, @@ -1571,7 +1567,7 @@ dev = [ { name = "pytest-cov", specifier = "~=4.1.0" }, { name = "pytest-env", specifier = "~=1.1.3" }, { name = "pytest-mock", specifier = "~=3.14.0" }, - { name = "ruff", specifier = "~=0.12.3" }, + { name = "ruff", specifier = 
"~=0.14.0" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.10.0" }, @@ -1609,8 +1605,6 @@ dev = [ { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-redis", specifier = ">=4.6.0.20241004" }, { name = "types-regex", specifier = "~=2024.11.6" }, - { name = "types-requests", specifier = "~=2.32.0" }, - { name = "types-requests-oauthlib", specifier = "~=2.0.0" }, { name = "types-setuptools", specifier = ">=80.9.0" }, { name = "types-shapely", specifier = "~=2.0.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, @@ -3910,21 +3904,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" }, ] -[[package]] -name = "opentelemetry-instrumentation-requests" -version = "0.48b0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "opentelemetry-api" }, - { name = "opentelemetry-instrumentation" }, - { name = "opentelemetry-semantic-conventions" }, - { name = "opentelemetry-util-http" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/52/ac/5eb78efde21ff21d0ad5dc8c6cc6a0f8ae482ce8a46293c2f45a628b6166/opentelemetry_instrumentation_requests-0.48b0.tar.gz", hash = "sha256:67ab9bd877a0352ee0db4616c8b4ae59736ddd700c598ed907482d44f4c9a2b3", size = 14120, upload-time = "2024-08-28T21:28:16.933Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/df/0df9226d1b14f29d23c07e6194b9fd5ad50e7d987b7fd13df7dcf718aeb1/opentelemetry_instrumentation_requests-0.48b0-py3-none-any.whl", hash = "sha256:d4f01852121d0bd4c22f14f429654a735611d4f7bf3cf93f244bdf1489b2233d", size = 12366, upload-time = "2024-08-28T21:27:20.771Z" }, -] - [[package]] name = 
"opentelemetry-instrumentation-sqlalchemy" version = "0.48b0" @@ -5461,28 +5440,28 @@ wheels = [ [[package]] name = "ruff" -version = "0.12.12" +version = "0.14.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a8/f0/e0965dd709b8cabe6356811c0ee8c096806bb57d20b5019eb4e48a117410/ruff-0.12.12.tar.gz", hash = "sha256:b86cd3415dbe31b3b46a71c598f4c4b2f550346d1ccf6326b347cc0c8fd063d6", size = 5359915, upload-time = "2025-09-04T16:50:18.273Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/b9/9bd84453ed6dd04688de9b3f3a4146a1698e8faae2ceeccce4e14c67ae17/ruff-0.14.0.tar.gz", hash = "sha256:62ec8969b7510f77945df916de15da55311fade8d6050995ff7f680afe582c57", size = 5452071, upload-time = "2025-10-07T18:21:55.763Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/09/79/8d3d687224d88367b51c7974cec1040c4b015772bfbeffac95face14c04a/ruff-0.12.12-py3-none-linux_armv6l.whl", hash = "sha256:de1c4b916d98ab289818e55ce481e2cacfaad7710b01d1f990c497edf217dafc", size = 12116602, upload-time = "2025-09-04T16:49:18.892Z" }, - { url = "https://files.pythonhosted.org/packages/c3/c3/6e599657fe192462f94861a09aae935b869aea8a1da07f47d6eae471397c/ruff-0.12.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7acd6045e87fac75a0b0cdedacf9ab3e1ad9d929d149785903cff9bb69ad9727", size = 12868393, upload-time = "2025-09-04T16:49:23.043Z" }, - { url = "https://files.pythonhosted.org/packages/e8/d2/9e3e40d399abc95336b1843f52fc0daaceb672d0e3c9290a28ff1a96f79d/ruff-0.12.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:abf4073688d7d6da16611f2f126be86523a8ec4343d15d276c614bda8ec44edb", size = 12036967, upload-time = "2025-09-04T16:49:26.04Z" }, - { url = "https://files.pythonhosted.org/packages/e9/03/6816b2ed08836be272e87107d905f0908be5b4a40c14bfc91043e76631b8/ruff-0.12.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:968e77094b1d7a576992ac078557d1439df678a34c6fe02fd979f973af167577", size = 
12276038, upload-time = "2025-09-04T16:49:29.056Z" }, - { url = "https://files.pythonhosted.org/packages/9f/d5/707b92a61310edf358a389477eabd8af68f375c0ef858194be97ca5b6069/ruff-0.12.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42a67d16e5b1ffc6d21c5f67851e0e769517fb57a8ebad1d0781b30888aa704e", size = 11901110, upload-time = "2025-09-04T16:49:32.07Z" }, - { url = "https://files.pythonhosted.org/packages/9d/3d/f8b1038f4b9822e26ec3d5b49cf2bc313e3c1564cceb4c1a42820bf74853/ruff-0.12.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b216ec0a0674e4b1214dcc998a5088e54eaf39417327b19ffefba1c4a1e4971e", size = 13668352, upload-time = "2025-09-04T16:49:35.148Z" }, - { url = "https://files.pythonhosted.org/packages/98/0e/91421368ae6c4f3765dd41a150f760c5f725516028a6be30e58255e3c668/ruff-0.12.12-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:59f909c0fdd8f1dcdbfed0b9569b8bf428cf144bec87d9de298dcd4723f5bee8", size = 14638365, upload-time = "2025-09-04T16:49:38.892Z" }, - { url = "https://files.pythonhosted.org/packages/74/5d/88f3f06a142f58ecc8ecb0c2fe0b82343e2a2b04dcd098809f717cf74b6c/ruff-0.12.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ac93d87047e765336f0c18eacad51dad0c1c33c9df7484c40f98e1d773876f5", size = 14060812, upload-time = "2025-09-04T16:49:42.732Z" }, - { url = "https://files.pythonhosted.org/packages/13/fc/8962e7ddd2e81863d5c92400820f650b86f97ff919c59836fbc4c1a6d84c/ruff-0.12.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:01543c137fd3650d322922e8b14cc133b8ea734617c4891c5a9fccf4bfc9aa92", size = 13050208, upload-time = "2025-09-04T16:49:46.434Z" }, - { url = "https://files.pythonhosted.org/packages/53/06/8deb52d48a9a624fd37390555d9589e719eac568c020b27e96eed671f25f/ruff-0.12.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afc2fa864197634e549d87fb1e7b6feb01df0a80fd510d6489e1ce8c0b1cc45", size = 13311444, 
upload-time = "2025-09-04T16:49:49.931Z" }, - { url = "https://files.pythonhosted.org/packages/2a/81/de5a29af7eb8f341f8140867ffb93f82e4fde7256dadee79016ac87c2716/ruff-0.12.12-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:0c0945246f5ad776cb8925e36af2438e66188d2b57d9cf2eed2c382c58b371e5", size = 13279474, upload-time = "2025-09-04T16:49:53.465Z" }, - { url = "https://files.pythonhosted.org/packages/7f/14/d9577fdeaf791737ada1b4f5c6b59c21c3326f3f683229096cccd7674e0c/ruff-0.12.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a0fbafe8c58e37aae28b84a80ba1817f2ea552e9450156018a478bf1fa80f4e4", size = 12070204, upload-time = "2025-09-04T16:49:56.882Z" }, - { url = "https://files.pythonhosted.org/packages/77/04/a910078284b47fad54506dc0af13839c418ff704e341c176f64e1127e461/ruff-0.12.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b9c456fb2fc8e1282affa932c9e40f5ec31ec9cbb66751a316bd131273b57c23", size = 11880347, upload-time = "2025-09-04T16:49:59.729Z" }, - { url = "https://files.pythonhosted.org/packages/df/58/30185fcb0e89f05e7ea82e5817b47798f7fa7179863f9d9ba6fd4fe1b098/ruff-0.12.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f12856123b0ad0147d90b3961f5c90e7427f9acd4b40050705499c98983f489", size = 12891844, upload-time = "2025-09-04T16:50:02.591Z" }, - { url = "https://files.pythonhosted.org/packages/21/9c/28a8dacce4855e6703dcb8cdf6c1705d0b23dd01d60150786cd55aa93b16/ruff-0.12.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:26a1b5a2bf7dd2c47e3b46d077cd9c0fc3b93e6c6cc9ed750bd312ae9dc302ee", size = 13360687, upload-time = "2025-09-04T16:50:05.8Z" }, - { url = "https://files.pythonhosted.org/packages/c8/fa/05b6428a008e60f79546c943e54068316f32ec8ab5c4f73e4563934fbdc7/ruff-0.12.12-py3-none-win32.whl", hash = "sha256:173be2bfc142af07a01e3a759aba6f7791aa47acf3604f610b1c36db888df7b1", size = 12052870, upload-time = "2025-09-04T16:50:09.121Z" }, - { url = 
"https://files.pythonhosted.org/packages/85/60/d1e335417804df452589271818749d061b22772b87efda88354cf35cdb7a/ruff-0.12.12-py3-none-win_amd64.whl", hash = "sha256:e99620bf01884e5f38611934c09dd194eb665b0109104acae3ba6102b600fd0d", size = 13178016, upload-time = "2025-09-04T16:50:12.559Z" }, - { url = "https://files.pythonhosted.org/packages/28/7e/61c42657f6e4614a4258f1c3b0c5b93adc4d1f8575f5229d1906b483099b/ruff-0.12.12-py3-none-win_arm64.whl", hash = "sha256:2a8199cab4ce4d72d158319b63370abf60991495fb733db96cd923a34c52d093", size = 12256762, upload-time = "2025-09-04T16:50:15.737Z" }, + { url = "https://files.pythonhosted.org/packages/3a/4e/79d463a5f80654e93fa653ebfb98e0becc3f0e7cf6219c9ddedf1e197072/ruff-0.14.0-py3-none-linux_armv6l.whl", hash = "sha256:58e15bffa7054299becf4bab8a1187062c6f8cafbe9f6e39e0d5aface455d6b3", size = 12494532, upload-time = "2025-10-07T18:21:00.373Z" }, + { url = "https://files.pythonhosted.org/packages/ee/40/e2392f445ed8e02aa6105d49db4bfff01957379064c30f4811c3bf38aece/ruff-0.14.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:838d1b065f4df676b7c9957992f2304e41ead7a50a568185efd404297d5701e8", size = 13160768, upload-time = "2025-10-07T18:21:04.73Z" }, + { url = "https://files.pythonhosted.org/packages/75/da/2a656ea7c6b9bd14c7209918268dd40e1e6cea65f4bb9880eaaa43b055cd/ruff-0.14.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:703799d059ba50f745605b04638fa7e9682cc3da084b2092feee63500ff3d9b8", size = 12363376, upload-time = "2025-10-07T18:21:07.833Z" }, + { url = "https://files.pythonhosted.org/packages/42/e2/1ffef5a1875add82416ff388fcb7ea8b22a53be67a638487937aea81af27/ruff-0.14.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ba9a8925e90f861502f7d974cc60e18ca29c72bb0ee8bfeabb6ade35a3abde7", size = 12608055, upload-time = "2025-10-07T18:21:10.72Z" }, + { url = 
"https://files.pythonhosted.org/packages/4a/32/986725199d7cee510d9f1dfdf95bf1efc5fa9dd714d0d85c1fb1f6be3bc3/ruff-0.14.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e41f785498bd200ffc276eb9e1570c019c1d907b07cfb081092c8ad51975bbe7", size = 12318544, upload-time = "2025-10-07T18:21:13.741Z" }, + { url = "https://files.pythonhosted.org/packages/9a/ed/4969cefd53315164c94eaf4da7cfba1f267dc275b0abdd593d11c90829a3/ruff-0.14.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30a58c087aef4584c193aebf2700f0fbcfc1e77b89c7385e3139956fa90434e2", size = 14001280, upload-time = "2025-10-07T18:21:16.411Z" }, + { url = "https://files.pythonhosted.org/packages/ab/ad/96c1fc9f8854c37681c9613d825925c7f24ca1acfc62a4eb3896b50bacd2/ruff-0.14.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f8d07350bc7af0a5ce8812b7d5c1a7293cf02476752f23fdfc500d24b79b783c", size = 15027286, upload-time = "2025-10-07T18:21:19.577Z" }, + { url = "https://files.pythonhosted.org/packages/b3/00/1426978f97df4fe331074baf69615f579dc4e7c37bb4c6f57c2aad80c87f/ruff-0.14.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eec3bbbf3a7d5482b5c1f42d5fc972774d71d107d447919fca620b0be3e3b75e", size = 14451506, upload-time = "2025-10-07T18:21:22.779Z" }, + { url = "https://files.pythonhosted.org/packages/58/d5/9c1cea6e493c0cf0647674cca26b579ea9d2a213b74b5c195fbeb9678e15/ruff-0.14.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16b68e183a0e28e5c176d51004aaa40559e8f90065a10a559176713fcf435206", size = 13437384, upload-time = "2025-10-07T18:21:25.758Z" }, + { url = "https://files.pythonhosted.org/packages/29/b4/4cd6a4331e999fc05d9d77729c95503f99eae3ba1160469f2b64866964e3/ruff-0.14.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb732d17db2e945cfcbbc52af0143eda1da36ca8ae25083dd4f66f1542fdf82e", size = 13447976, upload-time = "2025-10-07T18:21:28.83Z" }, + { url = 
"https://files.pythonhosted.org/packages/3b/c0/ac42f546d07e4f49f62332576cb845d45c67cf5610d1851254e341d563b6/ruff-0.14.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:c958f66ab884b7873e72df38dcabee03d556a8f2ee1b8538ee1c2bbd619883dd", size = 13682850, upload-time = "2025-10-07T18:21:31.842Z" }, + { url = "https://files.pythonhosted.org/packages/5f/c4/4b0c9bcadd45b4c29fe1af9c5d1dc0ca87b4021665dfbe1c4688d407aa20/ruff-0.14.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7eb0499a2e01f6e0c285afc5bac43ab380cbfc17cd43a2e1dd10ec97d6f2c42d", size = 12449825, upload-time = "2025-10-07T18:21:35.074Z" }, + { url = "https://files.pythonhosted.org/packages/4b/a8/e2e76288e6c16540fa820d148d83e55f15e994d852485f221b9524514730/ruff-0.14.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4c63b2d99fafa05efca0ab198fd48fa6030d57e4423df3f18e03aa62518c565f", size = 12272599, upload-time = "2025-10-07T18:21:38.08Z" }, + { url = "https://files.pythonhosted.org/packages/18/14/e2815d8eff847391af632b22422b8207704222ff575dec8d044f9ab779b2/ruff-0.14.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:668fce701b7a222f3f5327f86909db2bbe99c30877c8001ff934c5413812ac02", size = 13193828, upload-time = "2025-10-07T18:21:41.216Z" }, + { url = "https://files.pythonhosted.org/packages/44/c6/61ccc2987cf0aecc588ff8f3212dea64840770e60d78f5606cd7dc34de32/ruff-0.14.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a86bf575e05cb68dcb34e4c7dfe1064d44d3f0c04bbc0491949092192b515296", size = 13628617, upload-time = "2025-10-07T18:21:44.04Z" }, + { url = "https://files.pythonhosted.org/packages/73/e6/03b882225a1b0627e75339b420883dc3c90707a8917d2284abef7a58d317/ruff-0.14.0-py3-none-win32.whl", hash = "sha256:7450a243d7125d1c032cb4b93d9625dea46c8c42b4f06c6b709baac168e10543", size = 12367872, upload-time = "2025-10-07T18:21:46.67Z" }, + { url = "https://files.pythonhosted.org/packages/41/77/56cf9cf01ea0bfcc662de72540812e5ba8e9563f33ef3d37ab2174892c47/ruff-0.14.0-py3-none-win_amd64.whl", hash = 
"sha256:ea95da28cd874c4d9c922b39381cbd69cb7e7b49c21b8152b014bd4f52acddc2", size = 13464628, upload-time = "2025-10-07T18:21:50.318Z" }, + { url = "https://files.pythonhosted.org/packages/c6/2a/65880dfd0e13f7f13a775998f34703674a4554906167dce02daf7865b954/ruff-0.14.0-py3-none-win_arm64.whl", hash = "sha256:f42c9495f5c13ff841b1da4cb3c2a42075409592825dada7c5885c2c844ac730", size = 12565142, upload-time = "2025-10-07T18:21:53.577Z" }, ] [[package]] @@ -6440,19 +6419,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/6f/ec0012be842b1d888d46884ac5558fd62aeae1f0ec4f7a581433d890d4b5/types_requests-2.32.4.20250809-py3-none-any.whl", hash = "sha256:f73d1832fb519ece02c85b1f09d5f0dd3108938e7d47e7f94bbfa18a6782b163", size = 20644, upload-time = "2025-08-09T03:17:09.716Z" }, ] -[[package]] -name = "types-requests-oauthlib" -version = "2.0.0.20250809" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "types-oauthlib" }, - { name = "types-requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ed/40/5eca857a2dbda0fedd69b7fd3f51cb0b6ece8d448327d29f0ae54612ec98/types_requests_oauthlib-2.0.0.20250809.tar.gz", hash = "sha256:f3b9b31e0394fe2c362f0d44bc9ef6d5c150a298d01089513cd54a51daec37a2", size = 11008, upload-time = "2025-08-09T03:17:50.705Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/38/8777f0ab409a7249777f230f6aefe0e9ba98355dc8b05fb31391fa30f312/types_requests_oauthlib-2.0.0.20250809-py3-none-any.whl", hash = "sha256:0d1af4907faf9f4a1b0f0afbc7ec488f1dd5561a2b5b6dad70f78091a1acfb76", size = 14319, upload-time = "2025-08-09T03:17:49.786Z" }, -] - [[package]] name = "types-s3transfer" version = "0.13.1" diff --git a/scripts/stress-test/setup/import_workflow_app.py b/scripts/stress-test/setup/import_workflow_app.py index 86d0239e35..41a76bd29b 100755 --- a/scripts/stress-test/setup/import_workflow_app.py +++ b/scripts/stress-test/setup/import_workflow_app.py @@ -8,7 +8,7 @@ 
sys.path.append(str(Path(__file__).parent.parent)) import json import httpx -from common import Logger, config_helper +from common import Logger, config_helper # type: ignore[import] def import_workflow_app() -> None: diff --git a/web/app/components/base/date-and-time-picker/time-picker/index.spec.tsx b/web/app/components/base/date-and-time-picker/time-picker/index.spec.tsx new file mode 100644 index 0000000000..40bc2928c8 --- /dev/null +++ b/web/app/components/base/date-and-time-picker/time-picker/index.spec.tsx @@ -0,0 +1,95 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import TimePicker from './index' +import dayjs from '../utils/dayjs' +import { isDayjsObject } from '../utils/dayjs' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + if (key === 'time.defaultPlaceholder') return 'Pick a time...' + if (key === 'time.operation.now') return 'Now' + if (key === 'time.operation.ok') return 'OK' + if (key === 'common.operation.clear') return 'Clear' + return key + }, + }), +})) + +jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children }: { children: React.ReactNode }) =>
{children}
, + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: (e: React.MouseEvent) => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +jest.mock('./options', () => () =>
) +jest.mock('./header', () => () =>
) + +describe('TimePicker', () => { + const baseProps = { + onChange: jest.fn(), + onClear: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + test('renders formatted value for string input (Issue #26692 regression)', () => { + render( + , + ) + + expect(screen.getByDisplayValue('06:45 PM')).toBeInTheDocument() + }) + + test('confirms cleared value when confirming without selection', () => { + render( + , + ) + + const input = screen.getByRole('textbox') + fireEvent.click(input) + + const clearButton = screen.getByRole('button', { name: /clear/i }) + fireEvent.click(clearButton) + + const confirmButton = screen.getByRole('button', { name: 'OK' }) + fireEvent.click(confirmButton) + + expect(baseProps.onChange).toHaveBeenCalledTimes(1) + expect(baseProps.onChange).toHaveBeenCalledWith(undefined) + expect(baseProps.onClear).not.toHaveBeenCalled() + }) + + test('selecting current time emits timezone-aware value', () => { + const onChange = jest.fn() + render( + , + ) + + const nowButton = screen.getByRole('button', { name: 'Now' }) + fireEvent.click(nowButton) + + expect(onChange).toHaveBeenCalledTimes(1) + const emitted = onChange.mock.calls[0][0] + expect(isDayjsObject(emitted)).toBe(true) + expect(emitted?.utcOffset()).toBe(dayjs().tz('America/New_York').utcOffset()) + }) +}) diff --git a/web/app/components/base/date-and-time-picker/time-picker/index.tsx b/web/app/components/base/date-and-time-picker/time-picker/index.tsx index 1fb2cfed11..f23fcf8f4e 100644 --- a/web/app/components/base/date-and-time-picker/time-picker/index.tsx +++ b/web/app/components/base/date-and-time-picker/time-picker/index.tsx @@ -1,6 +1,13 @@ import React, { useCallback, useEffect, useRef, useState } from 'react' -import type { Period, TimePickerProps } from '../types' -import dayjs, { cloneTime, getDateWithTimezone, getHourIn12Hour } from '../utils/dayjs' +import type { Dayjs } from 'dayjs' +import { Period } from '../types' +import type { TimePickerProps } from 
'../types' +import dayjs, { + getDateWithTimezone, + getHourIn12Hour, + isDayjsObject, + toDayjs, +} from '../utils/dayjs' import { PortalToFollowElem, PortalToFollowElemContent, @@ -13,6 +20,11 @@ import { useTranslation } from 'react-i18next' import { RiCloseCircleFill, RiTimeLine } from '@remixicon/react' import cn from '@/utils/classnames' +const to24Hour = (hour12: string, period: Period) => { + const normalized = Number.parseInt(hour12, 10) % 12 + return period === Period.PM ? normalized + 12 : normalized +} + const TimePicker = ({ value, timezone, @@ -28,7 +40,11 @@ const TimePicker = ({ const [isOpen, setIsOpen] = useState(false) const containerRef = useRef(null) const isInitial = useRef(true) - const [selectedTime, setSelectedTime] = useState(() => value ? getDateWithTimezone({ timezone, date: value }) : undefined) + + // Initialize selectedTime + const [selectedTime, setSelectedTime] = useState(() => { + return toDayjs(value, { timezone }) + }) useEffect(() => { const handleClickOutside = (event: MouseEvent) => { @@ -39,20 +55,47 @@ const TimePicker = ({ return () => document.removeEventListener('mousedown', handleClickOutside) }, []) + // Track previous values to avoid unnecessary updates + const prevValueRef = useRef(value) + const prevTimezoneRef = useRef(timezone) + useEffect(() => { if (isInitial.current) { isInitial.current = false + // Save initial values on first render + prevValueRef.current = value + prevTimezoneRef.current = timezone return } - if (value) { - const newValue = getDateWithTimezone({ date: value, timezone }) - setSelectedTime(newValue) - onChange(newValue) + + // Only update when timezone changes but value doesn't + const valueChanged = prevValueRef.current !== value + const timezoneChanged = prevTimezoneRef.current !== timezone + + // Update reference values + prevValueRef.current = value + prevTimezoneRef.current = timezone + + // Skip if neither timezone changed nor value changed + if (!timezoneChanged && !valueChanged) return 
+ + if (value !== undefined && value !== null) { + const dayjsValue = toDayjs(value, { timezone }) + if (!dayjsValue) return + + setSelectedTime(dayjsValue) + + if (timezoneChanged && !valueChanged) + onChange(dayjsValue) + return } - else { - setSelectedTime(prev => prev ? getDateWithTimezone({ date: prev, timezone }) : undefined) - } - }, [timezone]) + + setSelectedTime((prev) => { + if (!isDayjsObject(prev)) + return undefined + return timezone ? getDateWithTimezone({ date: prev, timezone }) : prev + }) + }, [timezone, value, onChange]) const handleClickTrigger = (e: React.MouseEvent) => { e.stopPropagation() @@ -61,8 +104,16 @@ const TimePicker = ({ return } setIsOpen(true) - if (value) - setSelectedTime(value) + + if (value) { + const dayjsValue = toDayjs(value, { timezone }) + const needsUpdate = dayjsValue && ( + !selectedTime + || !isDayjsObject(selectedTime) + || !dayjsValue.isSame(selectedTime, 'minute') + ) + if (needsUpdate) setSelectedTime(dayjsValue) + } } const handleClear = (e: React.MouseEvent) => { @@ -73,42 +124,68 @@ const TimePicker = ({ } const handleTimeSelect = (hour: string, minute: string, period: Period) => { - const newTime = cloneTime(dayjs(), dayjs(`1/1/2000 ${hour}:${minute} ${period}`)) + const periodAdjustedHour = to24Hour(hour, period) + const nextMinute = Number.parseInt(minute, 10) setSelectedTime((prev) => { - return prev ? cloneTime(prev, newTime) : newTime + const reference = isDayjsObject(prev) + ? prev + : (timezone ? getDateWithTimezone({ timezone }) : dayjs()).startOf('minute') + return reference + .set('hour', periodAdjustedHour) + .set('minute', nextMinute) + .set('second', 0) + .set('millisecond', 0) }) } + const getSafeTimeObject = useCallback(() => { + if (isDayjsObject(selectedTime)) + return selectedTime + return (timezone ? 
getDateWithTimezone({ timezone }) : dayjs()).startOf('day') + }, [selectedTime, timezone]) + const handleSelectHour = useCallback((hour: string) => { - const time = selectedTime || dayjs().startOf('day') + const time = getSafeTimeObject() handleTimeSelect(hour, time.minute().toString().padStart(2, '0'), time.format('A') as Period) - }, [selectedTime]) + }, [getSafeTimeObject]) const handleSelectMinute = useCallback((minute: string) => { - const time = selectedTime || dayjs().startOf('day') + const time = getSafeTimeObject() handleTimeSelect(getHourIn12Hour(time).toString().padStart(2, '0'), minute, time.format('A') as Period) - }, [selectedTime]) + }, [getSafeTimeObject]) const handleSelectPeriod = useCallback((period: Period) => { - const time = selectedTime || dayjs().startOf('day') + const time = getSafeTimeObject() handleTimeSelect(getHourIn12Hour(time).toString().padStart(2, '0'), time.minute().toString().padStart(2, '0'), period) - }, [selectedTime]) + }, [getSafeTimeObject]) const handleSelectCurrentTime = useCallback(() => { const newDate = getDateWithTimezone({ timezone }) setSelectedTime(newDate) onChange(newDate) setIsOpen(false) - }, [onChange, timezone]) + }, [timezone, onChange]) const handleConfirm = useCallback(() => { - onChange(selectedTime) + const valueToEmit = isDayjsObject(selectedTime) ? selectedTime : undefined + onChange(valueToEmit) setIsOpen(false) - }, [onChange, selectedTime]) + }, [selectedTime, onChange]) const timeFormat = 'hh:mm A' - const displayValue = value?.format(timeFormat) || '' - const placeholderDate = isOpen && selectedTime ? 
selectedTime.format(timeFormat) : (placeholder || t('time.defaultPlaceholder')) + + const formatTimeValue = useCallback((timeValue: string | Dayjs | undefined): string => { + if (!timeValue) return '' + + const dayjsValue = toDayjs(timeValue, { timezone }) + return dayjsValue?.format(timeFormat) || '' + }, [timezone]) + + const displayValue = formatTimeValue(value) + + const placeholderDate = isOpen && isDayjsObject(selectedTime) + ? selectedTime.format(timeFormat) + : (placeholder || t('time.defaultPlaceholder')) const inputElem = (
diff --git a/web/app/components/base/date-and-time-picker/types.ts b/web/app/components/base/date-and-time-picker/types.ts index 4ac01c142a..b51c2ebb01 100644 --- a/web/app/components/base/date-and-time-picker/types.ts +++ b/web/app/components/base/date-and-time-picker/types.ts @@ -54,7 +54,7 @@ export type TriggerParams = { onClick: (e: React.MouseEvent) => void } export type TimePickerProps = { - value: Dayjs | undefined + value: Dayjs | string | undefined timezone?: string placeholder?: string onChange: (date: Dayjs | undefined) => void diff --git a/web/app/components/base/date-and-time-picker/utils/dayjs.spec.ts b/web/app/components/base/date-and-time-picker/utils/dayjs.spec.ts new file mode 100644 index 0000000000..549ab01029 --- /dev/null +++ b/web/app/components/base/date-and-time-picker/utils/dayjs.spec.ts @@ -0,0 +1,67 @@ +import dayjs from './dayjs' +import { + getDateWithTimezone, + isDayjsObject, + toDayjs, +} from './dayjs' + +describe('dayjs utilities', () => { + const timezone = 'UTC' + + test('toDayjs parses time-only strings with timezone support', () => { + const result = toDayjs('18:45', { timezone }) + expect(result).toBeDefined() + expect(result?.format('HH:mm')).toBe('18:45') + expect(result?.utcOffset()).toBe(getDateWithTimezone({ timezone }).utcOffset()) + }) + + test('toDayjs parses 12-hour time strings', () => { + const tz = 'America/New_York' + const result = toDayjs('07:15 PM', { timezone: tz }) + expect(result).toBeDefined() + expect(result?.format('HH:mm')).toBe('19:15') + expect(result?.utcOffset()).toBe(getDateWithTimezone({ timezone: tz }).utcOffset()) + }) + + test('isDayjsObject detects dayjs instances', () => { + const date = dayjs() + expect(isDayjsObject(date)).toBe(true) + expect(isDayjsObject(getDateWithTimezone({ timezone }))).toBe(true) + expect(isDayjsObject('2024-01-01')).toBe(false) + expect(isDayjsObject({})).toBe(false) + }) + + test('toDayjs parses datetime strings in target timezone', () => { + const value = 
'2024-05-01 12:00:00' + const tz = 'America/New_York' + + const result = toDayjs(value, { timezone: tz }) + + expect(result).toBeDefined() + expect(result?.hour()).toBe(12) + expect(result?.format('YYYY-MM-DD HH:mm')).toBe('2024-05-01 12:00') + }) + + test('toDayjs parses ISO datetime strings in target timezone', () => { + const value = '2024-05-01T14:30:00' + const tz = 'Europe/London' + + const result = toDayjs(value, { timezone: tz }) + + expect(result).toBeDefined() + expect(result?.hour()).toBe(14) + expect(result?.minute()).toBe(30) + }) + + test('toDayjs handles dates without time component', () => { + const value = '2024-05-01' + const tz = 'America/Los_Angeles' + + const result = toDayjs(value, { timezone: tz }) + + expect(result).toBeDefined() + expect(result?.format('YYYY-MM-DD')).toBe('2024-05-01') + expect(result?.hour()).toBe(0) + expect(result?.minute()).toBe(0) + }) +}) diff --git a/web/app/components/base/date-and-time-picker/utils/dayjs.ts b/web/app/components/base/date-and-time-picker/utils/dayjs.ts index fef35bf6ca..808b50247a 100644 --- a/web/app/components/base/date-and-time-picker/utils/dayjs.ts +++ b/web/app/components/base/date-and-time-picker/utils/dayjs.ts @@ -10,6 +10,25 @@ dayjs.extend(timezone) export default dayjs const monthMaps: Record = {} +const DEFAULT_OFFSET_STR = 'UTC+0' +const TIME_ONLY_REGEX = /^(\d{1,2}):(\d{2})(?::(\d{2})(?:\.(\d{1,3}))?)?$/ +const TIME_ONLY_12H_REGEX = /^(\d{1,2}):(\d{2})(?::(\d{2}))?\s?(AM|PM)$/i + +const COMMON_PARSE_FORMATS = [ + 'YYYY-MM-DD', + 'YYYY/MM/DD', + 'DD-MM-YYYY', + 'DD/MM/YYYY', + 'MM-DD-YYYY', + 'MM/DD/YYYY', + 'YYYY-MM-DDTHH:mm:ss.SSSZ', + 'YYYY-MM-DDTHH:mm:ssZ', + 'YYYY-MM-DD HH:mm:ss', + 'YYYY-MM-DDTHH:mm', + 'YYYY-MM-DDTHH:mmZ', + 'YYYY-MM-DDTHH:mm:ss', + 'YYYY-MM-DDTHH:mm:ss.SSS', +] export const cloneTime = (targetDate: Dayjs, sourceDate: Dayjs) => { return targetDate.clone() @@ -76,21 +95,116 @@ export const getHourIn12Hour = (date: Dayjs) => { return hour === 0 ? 12 : hour >= 12 ? 
hour - 12 : hour } -export const getDateWithTimezone = (props: { date?: Dayjs, timezone?: string }) => { - return props.date ? dayjs.tz(props.date, props.timezone) : dayjs().tz(props.timezone) +export const getDateWithTimezone = ({ date, timezone }: { date?: Dayjs, timezone?: string }) => { + if (!timezone) + return (date ?? dayjs()).clone() + return date ? dayjs.tz(date, timezone) : dayjs().tz(timezone) } -// Asia/Shanghai -> UTC+8 -const DEFAULT_OFFSET_STR = 'UTC+0' export const convertTimezoneToOffsetStr = (timezone?: string) => { if (!timezone) return DEFAULT_OFFSET_STR const tzItem = tz.find(item => item.value === timezone) - if(!tzItem) + if (!tzItem) return DEFAULT_OFFSET_STR return `UTC${tzItem.name.charAt(0)}${tzItem.name.charAt(2)}` } +export const isDayjsObject = (value: unknown): value is Dayjs => dayjs.isDayjs(value) + +export type ToDayjsOptions = { + timezone?: string + format?: string + formats?: string[] +} + +const warnParseFailure = (value: string) => { + if (process.env.NODE_ENV !== 'production') + console.warn('[TimePicker] Failed to parse time value', value) +} + +const normalizeMillisecond = (value: string | undefined) => { + if (!value) return 0 + if (value.length === 3) return Number(value) + if (value.length > 3) return Number(value.slice(0, 3)) + return Number(value.padEnd(3, '0')) +} + +const applyTimezone = (date: Dayjs, timezone?: string) => { + return timezone ? getDateWithTimezone({ date, timezone }) : date +} + +export const toDayjs = (value: string | Dayjs | undefined, options: ToDayjsOptions = {}): Dayjs | undefined => { + if (!value) + return undefined + + const { timezone: tzName, format, formats } = options + + if (isDayjsObject(value)) + return applyTimezone(value, tzName) + + if (typeof value !== 'string') + return undefined + + const trimmed = value.trim() + + if (format) { + const parsedWithFormat = tzName + ? 
dayjs.tz(trimmed, format, tzName, true) + : dayjs(trimmed, format, true) + if (parsedWithFormat.isValid()) + return parsedWithFormat + } + + const timeMatch = TIME_ONLY_REGEX.exec(trimmed) + if (timeMatch) { + const base = applyTimezone(dayjs(), tzName).startOf('day') + const rawHour = Number(timeMatch[1]) + const minute = Number(timeMatch[2]) + const second = timeMatch[3] ? Number(timeMatch[3]) : 0 + const millisecond = normalizeMillisecond(timeMatch[4]) + + return base + .set('hour', rawHour) + .set('minute', minute) + .set('second', second) + .set('millisecond', millisecond) + } + + const timeMatch12h = TIME_ONLY_12H_REGEX.exec(trimmed) + if (timeMatch12h) { + const base = applyTimezone(dayjs(), tzName).startOf('day') + let hour = Number(timeMatch12h[1]) % 12 + const isPM = timeMatch12h[4]?.toUpperCase() === 'PM' + if (isPM) + hour += 12 + const minute = Number(timeMatch12h[2]) + const second = timeMatch12h[3] ? Number(timeMatch12h[3]) : 0 + + return base + .set('hour', hour) + .set('minute', minute) + .set('second', second) + .set('millisecond', 0) + } + + const candidateFormats = formats ?? COMMON_PARSE_FORMATS + for (const fmt of candidateFormats) { + const parsed = tzName + ? dayjs.tz(trimmed, fmt, tzName, true) + : dayjs(trimmed, fmt, true) + if (parsed.isValid()) + return parsed + } + + const fallbackParsed = tzName ? 
dayjs.tz(trimmed, tzName) : dayjs(trimmed) + if (fallbackParsed.isValid()) + return fallbackParsed + + warnParseFailure(value) + return undefined +} + // Parse date with multiple format support export const parseDateWithFormat = (dateString: string, format?: string): Dayjs | null => { if (!dateString) return null @@ -103,15 +217,7 @@ export const parseDateWithFormat = (dateString: string, format?: string): Dayjs // Try common date formats const formats = [ - 'YYYY-MM-DD', // Standard format - 'YYYY/MM/DD', // Slash format - 'DD-MM-YYYY', // European format - 'DD/MM/YYYY', // European slash format - 'MM-DD-YYYY', // US format - 'MM/DD/YYYY', // US slash format - 'YYYY-MM-DDTHH:mm:ss.SSSZ', // ISO format - 'YYYY-MM-DDTHH:mm:ssZ', // ISO format (no milliseconds) - 'YYYY-MM-DD HH:mm:ss', // Standard datetime format + ...COMMON_PARSE_FORMATS, ] for (const fmt of formats) { diff --git a/web/app/components/base/icons/assets/public/tracing/aliyun-icon-big.svg b/web/app/components/base/icons/assets/public/tracing/aliyun-icon-big.svg index 210a1cd00b..d82b9bc1e4 100644 --- a/web/app/components/base/icons/assets/public/tracing/aliyun-icon-big.svg +++ b/web/app/components/base/icons/assets/public/tracing/aliyun-icon-big.svg @@ -1 +1 @@ - + \ No newline at end of file diff --git a/web/app/components/base/icons/assets/public/tracing/aliyun-icon.svg b/web/app/components/base/icons/assets/public/tracing/aliyun-icon.svg index 6f7645301c..cee8858471 100644 --- a/web/app/components/base/icons/assets/public/tracing/aliyun-icon.svg +++ b/web/app/components/base/icons/assets/public/tracing/aliyun-icon.svg @@ -1 +1 @@ - + \ No newline at end of file diff --git a/web/app/components/base/icons/src/public/tracing/AliyunIcon.json b/web/app/components/base/icons/src/public/tracing/AliyunIcon.json index 5cbb52c237..154aeff8c6 100644 --- a/web/app/components/base/icons/src/public/tracing/AliyunIcon.json +++ b/web/app/components/base/icons/src/public/tracing/AliyunIcon.json @@ -1,118 +1,129 @@ { 
- "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "xmlns": "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink", - "fill": "none", - "version": "1.1", - "width": "65", - "height": "16", - "viewBox": "0 0 65 16" - }, - "children": [ - { - "type": "element", - "name": "defs", - "children": [ - { - "type": "element", - "name": "clipPath", - "attributes": { - "id": "master_svg0_42_34281" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "x": "0", - "y": "0", - "width": "19", - "height": "16", - "rx": "0" - } - } - ] - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "g", - "attributes": { - "clip-path": "url(#master_svg0_42_34281)" - }, - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M4.06862,14.6667C3.79213,14.6667,3.45463,14.5688,3.05614,14.373C2.97908,14.3351,2.92692,14.3105,2.89968,14.2992C2.33193,14.0628,1.82911,13.7294,1.39123,13.2989C0.463742,12.3871,0,11.2874,0,10C0,8.71258,0.463742,7.61293,1.39123,6.70107C2.16172,5.94358,3.06404,5.50073,4.09819,5.37252C4.23172,3.98276,4.81755,2.77756,5.85569,1.75693C7.04708,0.585642,8.4857,0,10.1716,0C11.5256,0,12.743,0.396982,13.8239,1.19095C14.8847,1.97019,15.61,2.97855,16,4.21604L14.7045,4.61063C14.4016,3.64918,13.8374,2.86532,13.0121,2.25905C12.1719,1.64191,11.2251,1.33333,10.1716,1.33333C8.8602,1.33333,7.74124,1.7888,6.81467,2.69974C5.88811,3.61067,5.42483,4.71076,5.42483,6L5.42483,6.66667L4.74673,6.66667C3.81172,6.66667,3.01288,6.99242,2.35021,7.64393C1.68754,8.2954,1.35621,9.08076,1.35621,10C1.35621,10.9192,1.68754,11.7046,2.35021,12.3561C2.66354,12.6641,3.02298,12.9026,3.42852,13.0714C3.48193,13.0937,3.55988,13.13,3.66237,13.1803C3.87004,13.2823,4.00545,13.3333,4.06862,13.3333L4.06862,14.6667Z", - "fill-rule": "evenodd", 
- "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M13.458613505859375,7.779393492279053C12.975613505859375,7.717463492279053,12.484813505859375,7.686503492279053,11.993983505859376,7.686503492279053C11.152583505859376,7.686503492279053,10.303403505859375,7.779393492279053,9.493183505859374,7.941943492279053C8.682953505859375,8.104503492279052,7.903893505859375,8.359943492279053,7.155983505859375,8.654083492279053C6.657383505859375,8.870823492279053,6.158783505859375,9.128843492279053,5.660181505859375,9.428153492279053C5.332974751859375,9.621673492279053,5.239486705859375,10.070633492279054,5.434253505859375,10.395743492279053L7.413073505859375,13.298533492279052C7.639003505859375,13.623603492279052,8.090863505859375,13.716463492279052,8.418073505859375,13.523003492279052C8.547913505859375,13.435263492279052,8.763453505859374,13.326893492279053,9.064693505859374,13.197863492279053C9.516553505859374,13.004333492279052,9.976203505859374,12.872733492279053,10.459223505859375,12.779863492279052C10.942243505859375,12.679263492279052,11.433053505859375,12.617333492279052,11.955023505859375,12.617333492279052L13.380683505859375,7.810353492279052L13.458613505859375,7.779393492279053ZM15.273813505859374,8.135463492279053L15.016753505859375,5.333333492279053L13.458613505859375,7.787133492279053C13.817013505859375,7.818093492279052,14.144213505859375,7.880023492279053,14.494743505859375,7.949683492279053C14.494743505859375,7.944523492279053,14.754433505859375,8.006453492279054,15.273813505859374,8.135463492279053ZM12.064083505859376,12.648273492279053L11.378523505859375,14.970463492279054L12.515943505859376,16.00003349227905L14.074083505859376,15.643933492279054L14.525943505859376,13.027603492279052C14.198743505859374,12.934663492279054,13.879283505859375,12.834063492279054,13.552083505859375,12.772133492279053C13.069083505859375,12.71793349227
9052,12.578283505859375,12.648273492279053,12.064083505859376,12.648273492279053ZM18.327743505859374,9.428153492279053C17.829143505859374,9.128843492279053,17.330543505859374,8.870823492279053,16.831943505859375,8.654083492279053C16.348943505859374,8.460573492279053,15.826943505859376,8.267053492279054,15.305013505859375,8.135463492279053L15.305013505859375,8.267053492279054L14.463613505859374,13.043063492279053C14.596083505859376,13.105003492279053,14.759683505859375,13.135933492279053,14.884283505859376,13.205603492279053C15.185523505859376,13.334623492279052,15.401043505859375,13.443003492279052,15.530943505859375,13.530733492279053C15.858143505859376,13.724263492279054,16.341143505859375,13.623603492279052,16.535943505859375,13.306263492279053L18.514743505859375,10.403483492279053C18.779643505859376,10.039673492279054,18.686143505859377,9.621673492279053,18.327743505859374,9.428153492279053Z", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - } - ] - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": 
"M25.044,2.668L34.676,2.668L34.676,4.04L25.044,4.04L25.044,2.668ZM29.958,7.82Q29.258,9.066,28.355,10.41Q27.451999999999998,11.754,26.92,12.3L32.506,11.782Q31.442,10.158,30.84,9.346L32.058,8.562000000000001Q32.786,9.5,33.843,11.012Q34.9,12.524,35.516,13.546L34.214,14.526Q33.891999999999996,13.966,33.346000000000004,13.098Q32.016,13.182,29.734,13.378Q27.451999999999998,13.574,25.87,13.742L25.31,13.812L24.834,13.882L24.414,12.468Q24.708,12.37,24.862000000000002,12.265Q25.016,12.16,25.121,12.069Q25.226,11.978,25.268,11.936Q25.912,11.32,26.724,10.165Q27.536,9.01,28.208,7.82L23.854,7.82L23.854,6.434L35.866,6.434L35.866,7.82L29.958,7.82ZM42.656,7.414L42.656,8.576L41.354,8.576L41.354,1.814L42.656,1.87L42.656,7.036Q43.314,5.846,43.888000000000005,4.369Q44.462,2.892,44.714,1.6600000000000001L46.086,1.981999Q45.96,2.612,45.722,3.41L49.6,3.41L49.6,4.74L45.274,4.74Q44.616,6.56,43.706,8.128L42.656,7.414ZM38.596000000000004,2.346L39.884,2.402L39.884,8.212L38.596000000000004,8.212L38.596000000000004,2.346ZM46.184,4.964Q46.688,5.356,47.5,6.175Q48.312,6.994,48.788,7.582L47.751999999999995,8.59Q47.346000000000004,8.072,46.576,7.274Q45.806,6.476,45.204,5.902L46.184,4.964ZM48.41,9.01L48.41,12.706L49.894,12.706L49.894,13.966L37.391999999999996,13.966L37.391999999999996,12.706L38.848,12.706L38.848,9.01L48.41,9.01ZM41.676,10.256L40.164,10.256L40.164,12.706L41.676,12.706L41.676,10.256ZM42.908,12.706L44.364000000000004,12.706L44.364000000000004,10.256L42.908,10.256L42.908,12.706ZM45.582,12.706L47.108000000000004,12.706L47.108000000000004,10.256L45.582,10.256L45.582,12.706ZM54.906,7.456L55.116,8.394L54.178,8.814L54.178,12.818Q54.178,13.434,54.031,13.735Q53.884,14.036,53.534,14.162Q53.184,14.288,52.456,14.358L51.867999999999995,14.414L51.476,13.084L52.162,13.028Q52.512,13,52.652,12.958Q52.792,12.916,52.841,12.797Q52.89,12.678,52.89,12.384L52.89,9.36Q51.980000000000004,9.724,51.322,9.948L51.013999999999996,8.576Q51.798,8.324,52.89,7.876L52.89,5.524L51.42,5.524L51.42,4.166L52.89,4.166L52.89,1.75
79989999999999L54.178,1.814L54.178,4.166L55.214,4.166L55.214,5.524L54.178,5.524L54.178,7.316L54.808,7.022L54.906,7.456ZM56.894,4.5440000000000005L56.894,6.098L55.564,6.098L55.564,3.256L58.686,3.256Q58.42,2.346,58.266,1.9260000000000002L59.624,1.7579989999999999Q59.848,2.276,60.142,3.256L63.25,3.256L63.25,6.098L61.962,6.098L61.962,4.5440000000000005L56.894,4.5440000000000005ZM59.008,6.322Q58.392,6.938,57.685,7.512Q56.978,8.086,55.956,8.841999999999999L55.242,7.764Q56.824,6.728,58.126,5.37L59.008,6.322ZM60.422,5.37Q61.024,5.776,62.095,6.581Q63.166,7.386,63.656,7.806L62.942,8.982Q62.368,8.45,61.332,7.652Q60.296,6.854,59.666,6.434L60.422,5.37ZM62.592,10.256L60.044,10.256L60.044,12.566L63.572,12.566L63.572,13.826L55.144,13.826L55.144,12.566L58.63,12.566L58.63,10.256L56.054,10.256L56.054,8.982L62.592,8.982L62.592,10.256Z", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - } - ] - } - ] - } - ] - }, - "name": "AliyunIcon" + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + "fill": "none", + "version": "1.1", + "width": "65", + "height": "16", + "viewBox": "0 0 65 16" + }, + "children": [ + { + "type": "element", + "name": "defs", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "clipPath", + "attributes": { + "id": "master_svg0_42_34281" + }, + "children": [ + { + "type": "element", + "name": "rect", + "attributes": { + "x": "0", + "y": "0", + "width": "19", + "height": "16", + "rx": "0" + }, + "children": [] + } + ] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "clip-path": "url(#master_svg0_42_34281)" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", 
+ "attributes": { + "d": "M4.06862,14.6667C3.79213,14.6667,3.45463,14.5688,3.05614,14.373C2.97908,14.3351,2.92692,14.3105,2.89968,14.2992C2.33193,14.0628,1.82911,13.7294,1.39123,13.2989C0.463742,12.3871,0,11.2874,0,10C0,8.71258,0.463742,7.61293,1.39123,6.70107C2.16172,5.94358,3.06404,5.50073,4.09819,5.37252C4.23172,3.98276,4.81755,2.77756,5.85569,1.75693C7.04708,0.585642,8.4857,0,10.1716,0C11.5256,0,12.743,0.396982,13.8239,1.19095C14.8847,1.97019,15.61,2.97855,16,4.21604L14.7045,4.61063C14.4016,3.64918,13.8374,2.86532,13.0121,2.25905C12.1719,1.64191,11.2251,1.33333,10.1716,1.33333C8.8602,1.33333,7.74124,1.7888,6.81467,2.69974C5.88811,3.61067,5.42483,4.71076,5.42483,6L5.42483,6.66667L4.74673,6.66667C3.81172,6.66667,3.01288,6.99242,2.35021,7.64393C1.68754,8.2954,1.35621,9.08076,1.35621,10C1.35621,10.9192,1.68754,11.7046,2.35021,12.3561C2.66354,12.6641,3.02298,12.9026,3.42852,13.0714C3.48193,13.0937,3.55988,13.13,3.66237,13.1803C3.87004,13.2823,4.00545,13.3333,4.06862,13.3333L4.06862,14.6667Z", + "fill-rule": "evenodd", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": 
"M13.458613505859375,7.779393492279053C12.975613505859375,7.717463492279053,12.484813505859375,7.686503492279053,11.993983505859376,7.686503492279053C11.152583505859376,7.686503492279053,10.303403505859375,7.779393492279053,9.493183505859374,7.941943492279053C8.682953505859375,8.104503492279052,7.903893505859375,8.359943492279053,7.155983505859375,8.654083492279053C6.657383505859375,8.870823492279053,6.158783505859375,9.128843492279053,5.660181505859375,9.428153492279053C5.332974751859375,9.621673492279053,5.239486705859375,10.070633492279054,5.434253505859375,10.395743492279053L7.413073505859375,13.298533492279052C7.639003505859375,13.623603492279052,8.090863505859375,13.716463492279052,8.418073505859375,13.523003492279052C8.547913505859375,13.435263492279052,8.763453505859374,13.326893492279053,9.064693505859374,13.197863492279053C9.516553505859374,13.004333492279052,9.976203505859374,12.872733492279053,10.459223505859375,12.779863492279052C10.942243505859375,12.679263492279052,11.433053505859375,12.617333492279052,11.955023505859375,12.617333492279052L13.380683505859375,7.810353492279052L13.458613505859375,7.779393492279053ZM15.273813505859374,8.135463492279053L15.016753505859375,5.333333492279053L13.458613505859375,7.787133492279053C13.817013505859375,7.818093492279052,14.144213505859375,7.880023492279053,14.494743505859375,7.949683492279053C14.494743505859375,7.944523492279053,14.754433505859375,8.006453492279054,15.273813505859374,8.135463492279053ZM12.064083505859376,12.648273492279053L11.378523505859375,14.970463492279054L12.515943505859376,16.00003349227905L14.074083505859376,15.643933492279054L14.525943505859376,13.027603492279052C14.198743505859374,12.934663492279054,13.879283505859375,12.834063492279054,13.552083505859375,12.772133492279053C13.069083505859375,12.717933492279052,12.578283505859375,12.648273492279053,12.064083505859376,12.648273492279053ZM18.327743505859374,9.428153492279053C17.829143505859374,9.128843492279053,17.330543505859374,8.8708234
92279053,16.831943505859375,8.654083492279053C16.348943505859374,8.460573492279053,15.826943505859376,8.267053492279054,15.305013505859375,8.135463492279053L15.305013505859375,8.267053492279054L14.463613505859374,13.043063492279053C14.596083505859376,13.105003492279053,14.759683505859375,13.135933492279053,14.884283505859376,13.205603492279053C15.185523505859376,13.334623492279052,15.401043505859375,13.443003492279052,15.530943505859375,13.530733492279053C15.858143505859376,13.724263492279054,16.341143505859375,13.623603492279052,16.535943505859375,13.306263492279053L18.514743505859375,10.403483492279053C18.779643505859376,10.039673492279054,18.686143505859377,9.621673492279053,18.327743505859374,9.428153492279053Z", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + } + ] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": 
"M25.044,2.668L34.676,2.668L34.676,4.04L25.044,4.04L25.044,2.668ZM29.958,7.82Q29.258,9.066,28.355,10.41Q27.451999999999998,11.754,26.92,12.3L32.506,11.782Q31.442,10.158,30.84,9.346L32.058,8.562000000000001Q32.786,9.5,33.843,11.012Q34.9,12.524,35.516,13.546L34.214,14.526Q33.891999999999996,13.966,33.346000000000004,13.098Q32.016,13.182,29.734,13.378Q27.451999999999998,13.574,25.87,13.742L25.31,13.812L24.834,13.882L24.414,12.468Q24.708,12.37,24.862000000000002,12.265Q25.016,12.16,25.121,12.069Q25.226,11.978,25.268,11.936Q25.912,11.32,26.724,10.165Q27.536,9.01,28.208,7.82L23.854,7.82L23.854,6.434L35.866,6.434L35.866,7.82L29.958,7.82ZM42.656,7.414L42.656,8.576L41.354,8.576L41.354,1.814L42.656,1.87L42.656,7.036Q43.314,5.846,43.888000000000005,4.369Q44.462,2.892,44.714,1.6600000000000001L46.086,1.981999Q45.96,2.612,45.722,3.41L49.6,3.41L49.6,4.74L45.274,4.74Q44.616,6.56,43.706,8.128L42.656,7.414ZM38.596000000000004,2.346L39.884,2.402L39.884,8.212L38.596000000000004,8.212L38.596000000000004,2.346ZM46.184,4.964Q46.688,5.356,47.5,6.175Q48.312,6.994,48.788,7.582L47.751999999999995,8.59Q47.346000000000004,8.072,46.576,7.274Q45.806,6.476,45.204,5.902L46.184,4.964ZM48.41,9.01L48.41,12.706L49.894,12.706L49.894,13.966L37.391999999999996,13.966L37.391999999999996,12.706L38.848,12.706L38.848,9.01L48.41,9.01ZM41.676,10.256L40.164,10.256L40.164,12.706L41.676,12.706L41.676,10.256ZM42.908,12.706L44.364000000000004,12.706L44.364000000000004,10.256L42.908,10.256L42.908,12.706ZM45.582,12.706L47.108000000000004,12.706L47.108000000000004,10.256L45.582,10.256L45.582,12.706ZM54.906,7.456L55.116,8.394L54.178,8.814L54.178,12.818Q54.178,13.434,54.031,13.735Q53.884,14.036,53.534,14.162Q53.184,14.288,52.456,14.358L51.867999999999995,14.414L51.476,13.084L52.162,13.028Q52.512,13,52.652,12.958Q52.792,12.916,52.841,12.797Q52.89,12.678,52.89,12.384L52.89,9.36Q51.980000000000004,9.724,51.322,9.948L51.013999999999996,8.576Q51.798,8.324,52.89,7.876L52.89,5.524L51.42,5.524L51.42,4.166L52.89,4.166L52.89,1.75
79989999999999L54.178,1.814L54.178,4.166L55.214,4.166L55.214,5.524L54.178,5.524L54.178,7.316L54.808,7.022L54.906,7.456ZM56.894,4.5440000000000005L56.894,6.098L55.564,6.098L55.564,3.256L58.686,3.256Q58.42,2.346,58.266,1.9260000000000002L59.624,1.7579989999999999Q59.848,2.276,60.142,3.256L63.25,3.256L63.25,6.098L61.962,6.098L61.962,4.5440000000000005L56.894,4.5440000000000005ZM59.008,6.322Q58.392,6.938,57.685,7.512Q56.978,8.086,55.956,8.841999999999999L55.242,7.764Q56.824,6.728,58.126,5.37L59.008,6.322ZM60.422,5.37Q61.024,5.776,62.095,6.581Q63.166,7.386,63.656,7.806L62.942,8.982Q62.368,8.45,61.332,7.652Q60.296,6.854,59.666,6.434L60.422,5.37ZM62.592,10.256L60.044,10.256L60.044,12.566L63.572,12.566L63.572,13.826L55.144,13.826L55.144,12.566L58.63,12.566L58.63,10.256L56.054,10.256L56.054,8.982L62.592,8.982L62.592,10.256Z", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + } + ] + } + ] + } + ] + }, + "name": "AliyunIcon" } diff --git a/web/app/components/base/icons/src/public/tracing/AliyunIconBig.json b/web/app/components/base/icons/src/public/tracing/AliyunIconBig.json index ea60744daf..7ed5166461 100644 --- a/web/app/components/base/icons/src/public/tracing/AliyunIconBig.json +++ b/web/app/components/base/icons/src/public/tracing/AliyunIconBig.json @@ -1,71 +1,78 @@ { - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "xmlns": "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink", - "fill": "none", - "version": "1.1", - "width": "96", - "height": "24", - "viewBox": "0 0 96 24" - }, - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": 
"M6.10294,22C5.68819,22,5.18195,21.8532,4.58421,21.5595C4.46861,21.5027,4.39038,21.4658,4.34951,21.4488C3.49789,21.0943,2.74367,20.5941,2.08684,19.9484C0.695613,18.5806,0,16.9311,0,15C0,13.0689,0.695612,11.4194,2.08684,10.0516C3.24259,8.91537,4.59607,8.2511,6.14728,8.05878C6.34758,5.97414,7.22633,4.16634,8.78354,2.63539C10.5706,0.878463,12.7286,0,15.2573,0C17.2884,0,19.1146,0.595472,20.7358,1.78642C22.327,2.95528,23.4151,4.46783,24,6.32406L22.0568,6.91594C21.6024,5.47377,20.7561,4.29798,19.5181,3.38858C18.2579,2.46286,16.8377,2,15.2573,2C13.2903,2,11.6119,2.6832,10.222,4.04961C8.83217,5.41601,8.13725,7.06614,8.13725,9L8.13725,10L7.12009,10C5.71758,10,4.51932,10.4886,3.52532,11.4659C2.53132,12.4431,2.03431,13.6211,2.03431,15C2.03431,16.3789,2.53132,17.5569,3.52532,18.5341C3.99531,18.9962,4.53447,19.3538,5.14278,19.6071C5.2229,19.6405,5.33983,19.695,5.49356,19.7705C5.80505,19.9235,6.00818,20,6.10294,20L6.10294,22Z", - "fill-rule": "evenodd", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": 
"M20.18796103515625,11.66909C19.46346103515625,11.5762,18.72726103515625,11.52975,17.991011035156248,11.52975C16.728921035156247,11.52975,15.45515103515625,11.66909,14.23981103515625,11.91292C13.02447103515625,12.156749999999999,11.85588103515625,12.539909999999999,10.73402103515625,12.98113C9.98612103515625,13.306239999999999,9.23822103515625,13.69327,8.49031803515625,14.14223C7.99950790415625,14.43251,7.85927603515625,15.10595,8.15142503515625,15.59361L11.11966103515625,19.9478C11.45855103515625,20.4354,12.13634103515625,20.5747,12.627151035156249,20.2845C12.821921035156251,20.152900000000002,13.14523103515625,19.990299999999998,13.59708103515625,19.796799999999998C14.27487103515625,19.506500000000003,14.964341035156249,19.3091,15.68887103515625,19.169800000000002C16.413401035156248,19.018900000000002,17.14962103515625,18.926000000000002,17.93258103515625,18.926000000000002L20.071061035156248,11.715530000000001L20.18796103515625,11.66909ZM22.91076103515625,12.20319L22.525161035156252,8L20.18796103515625,11.6807C20.72556103515625,11.72714,21.21636103515625,11.82003,21.74216103515625,11.92453C21.74216103515625,11.91679,22.13166103515625,12.00968,22.91076103515625,12.20319ZM18.09616103515625,18.9724L17.06782103515625,22.4557L18.773961035156248,24L21.11116103515625,23.465899999999998L21.788961035156248,19.5414C21.298161035156248,19.402,20.81896103515625,19.2511,20.32816103515625,19.1582C19.60366103515625,19.076900000000002,18.86746103515625,18.9724,18.09616103515625,18.9724ZM27.49166103515625,14.14223C26.74376103515625,13.69327,25.99586103515625,13.306239999999999,25.24796103515625,12.98113C24.52346103515625,12.69086,23.74046103515625,12.40058,22.95756103515625,12.20319L22.95756103515625,12.40058L21.69546103515625,19.5646C21.89416103515625,19.6575,22.139561035156248,19.7039,22.32646103515625,19.8084C22.77836103515625,20.0019,23.101661035156248,20.1645,23.29646103515625,20.2961C23.78726103515625,20.586399999999998,24.51176103515625,20.4354,24.80396103515625,19.95940000
0000002L27.77216103515625,15.605229999999999C28.16946103515625,15.05951,28.02926103515625,14.43251,27.49166103515625,14.14223Z", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - }, - { - "type": "element", - "name": "g", - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M35.785,3.8624638671875L50.233000000000004,3.8624638671875L50.233000000000004,5.9204638671875L35.785,5.9204638671875L35.785,3.8624638671875ZM43.156,11.5904638671875Q42.106,13.4594638671875,40.7515,15.4754638671875Q39.397,17.4914638671875,38.599000000000004,18.3104638671875L46.978,17.5334638671875Q45.382,15.0974638671875,44.479,13.8794638671875L46.306,12.7034638671875Q47.397999999999996,14.1104638671875,48.9835,16.3784638671875Q50.569,18.6464638671875,51.492999999999995,20.1794638671875L49.54,21.6494638671875Q49.057,20.8094638671875,48.238,19.5074638671875Q46.243,19.6334638671875,42.82,19.9274638671875Q39.397,20.2214638671875,37.024,20.4734638671875L36.184,20.5784638671875L35.47,20.6834638671875L34.84,18.5624638671875Q35.281,18.4154638671875,35.512,18.2579638671875Q35.743,18.1004638671875,35.9005,17.963963867187502Q36.058,17.8274638671875,36.121,17.7644638671875Q37.087,16.840463867187502,38.305,15.1079638671875Q39.522999999999996,13.3754638671875,40.531,11.5904638671875L34,11.5904638671875L34,9.5114638671875L52.018,9.5114638671875L52.018,11.5904638671875L43.156,11.5904638671875ZM62.203,10.9814638671875L62.203,12.7244638671875L60.25,12.7244638671875L60.25,2.5814638671875L62.203,2.6654638671875L62.203,10.4144638671875Q63.19,8.6294638671875,64.051,6.4139638671875Q64.912,4.1984638671875,65.28999999999999,2.3504638671875L67.348,2.8334628671875Q67.15899999999999,3.7784638671875,66.80199999999999,4.9754638671875L72.619,4.9754638671875L72.619,6.9704638671875L66.13,6.9704638671875Q65.143,9.7004638671875,63.778,12.0524638671875L62.203,10.9814638671875ZM56.113,3.3794638671875L58.045,3.4634638671875L58.045,12.1784638671875L56.113,12.1784638671875L56.113,3.3794638671875ZM
67.495,7.3064638671875Q68.251,7.8944638671875,69.469,9.1229638671875Q70.687,10.3514638671875,71.40100000000001,11.2334638671875L69.84700000000001,12.7454638671875Q69.238,11.9684638671875,68.083,10.7714638671875Q66.928,9.5744638671875,66.025,8.7134638671875L67.495,7.3064638671875ZM70.834,13.3754638671875L70.834,18.9194638671875L73.06,18.9194638671875L73.06,20.8094638671875L54.307,20.8094638671875L54.307,18.9194638671875L56.491,18.9194638671875L56.491,13.3754638671875L70.834,13.3754638671875ZM60.733000000000004,15.2444638671875L58.465,15.2444638671875L58.465,18.9194638671875L60.733000000000004,18.9194638671875L60.733000000000004,15.2444638671875ZM62.581,18.9194638671875L64.765,18.9194638671875L64.765,15.2444638671875L62.581,15.2444638671875L62.581,18.9194638671875ZM66.592,18.9194638671875L68.881,18.9194638671875L68.881,15.2444638671875L66.592,15.2444638671875L66.592,18.9194638671875ZM80.578,11.0444638671875L80.893,12.4514638671875L79.48599999999999,13.0814638671875L79.48599999999999,19.0874638671875Q79.48599999999999,20.0114638671875,79.2655,20.4629638671875Q79.045,20.9144638671875,78.52000000000001,21.1034638671875Q77.995,21.2924638671875,76.90299999999999,21.3974638671875L76.021,21.4814638671875L75.43299999999999,19.4864638671875L76.462,19.4024638671875Q76.987,19.3604638671875,77.197,19.2974638671875Q77.407,19.2344638671875,77.4805,19.0559638671875Q77.554,18.8774638671875,77.554,18.4364638671875L77.554,13.9004638671875Q76.189,14.4464638671875,75.202,14.7824638671875L74.74000000000001,12.7244638671875Q75.916,12.3464638671875,77.554,11.6744638671875L77.554,8.1464638671875L75.34899999999999,8.1464638671875L75.34899999999999,6.1094638671875L77.554,6.1094638671875L77.554,2.4974628671875L79.48599999999999,2.5814638671875L79.48599999999999,6.1094638671875L81.03999999999999,6.1094638671875L81.03999999999999,8.1464638671875L79.48599999999999,8.1464638671875L79.48599999999999,10.8344638671875L80.431,10.3934638671875L80.578,11.0444638671875ZM83.56,6.6764638671875L83.56,9.00746
38671875L81.565,9.0074638671875L81.565,4.7444638671875L86.24799999999999,4.7444638671875Q85.84899999999999,3.3794638671875,85.618,2.7494638671875L87.655,2.4974628671875Q87.991,3.2744638671875,88.432,4.7444638671875L93.094,4.7444638671875L93.094,9.0074638671875L91.162,9.0074638671875L91.162,6.6764638671875L83.56,6.6764638671875ZM86.731,9.3434638671875Q85.807,10.2674638671875,84.7465,11.1284638671875Q83.686,11.9894638671875,82.15299999999999,13.1234638671875L81.082,11.5064638671875Q83.455,9.9524638671875,85.408,7.9154638671875L86.731,9.3434638671875ZM88.852,7.9154638671875Q89.755,8.5244638671875,91.3615,9.731963867187499Q92.968,10.9394638671875,93.703,11.5694638671875L92.632,13.3334638671875Q91.771,12.5354638671875,90.217,11.3384638671875Q88.663,10.1414638671875,87.718,9.5114638671875L88.852,7.9154638671875ZM92.107,15.2444638671875L88.285,15.2444638671875L88.285,18.7094638671875L93.577,18.7094638671875L93.577,20.5994638671875L80.935,20.5994638671875L80.935,18.7094638671875L86.164,18.7094638671875L86.164,15.2444638671875L82.3,15.2444638671875L82.3,13.3334638671875L92.107,13.3334638671875L92.107,15.2444638671875Z", - "fill": "#FF6A00", - "fill-opacity": "1" - } - } - ] - } - ] - } - ] - }, - "name": "AliyunBigIcon" + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + "fill": "none", + "version": "1.1", + "width": "96", + "height": "24", + "viewBox": "0 0 96 24" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": 
"M6.10294,22C5.68819,22,5.18195,21.8532,4.58421,21.5595C4.46861,21.5027,4.39038,21.4658,4.34951,21.4488C3.49789,21.0943,2.74367,20.5941,2.08684,19.9484C0.695613,18.5806,0,16.9311,0,15C0,13.0689,0.695612,11.4194,2.08684,10.0516C3.24259,8.91537,4.59607,8.2511,6.14728,8.05878C6.34758,5.97414,7.22633,4.16634,8.78354,2.63539C10.5706,0.878463,12.7286,0,15.2573,0C17.2884,0,19.1146,0.595472,20.7358,1.78642C22.327,2.95528,23.4151,4.46783,24,6.32406L22.0568,6.91594C21.6024,5.47377,20.7561,4.29798,19.5181,3.38858C18.2579,2.46286,16.8377,2,15.2573,2C13.2903,2,11.6119,2.6832,10.222,4.04961C8.83217,5.41601,8.13725,7.06614,8.13725,9L8.13725,10L7.12009,10C5.71758,10,4.51932,10.4886,3.52532,11.4659C2.53132,12.4431,2.03431,13.6211,2.03431,15C2.03431,16.3789,2.53132,17.5569,3.52532,18.5341C3.99531,18.9962,4.53447,19.3538,5.14278,19.6071C5.2229,19.6405,5.33983,19.695,5.49356,19.7705C5.80505,19.9235,6.00818,20,6.10294,20L6.10294,22Z", + "fill-rule": "evenodd", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": 
"M20.18796103515625,11.66909C19.46346103515625,11.5762,18.72726103515625,11.52975,17.991011035156248,11.52975C16.728921035156247,11.52975,15.45515103515625,11.66909,14.23981103515625,11.91292C13.02447103515625,12.156749999999999,11.85588103515625,12.539909999999999,10.73402103515625,12.98113C9.98612103515625,13.306239999999999,9.23822103515625,13.69327,8.49031803515625,14.14223C7.99950790415625,14.43251,7.85927603515625,15.10595,8.15142503515625,15.59361L11.11966103515625,19.9478C11.45855103515625,20.4354,12.13634103515625,20.5747,12.627151035156249,20.2845C12.821921035156251,20.152900000000002,13.14523103515625,19.990299999999998,13.59708103515625,19.796799999999998C14.27487103515625,19.506500000000003,14.964341035156249,19.3091,15.68887103515625,19.169800000000002C16.413401035156248,19.018900000000002,17.14962103515625,18.926000000000002,17.93258103515625,18.926000000000002L20.071061035156248,11.715530000000001L20.18796103515625,11.66909ZM22.91076103515625,12.20319L22.525161035156252,8L20.18796103515625,11.6807C20.72556103515625,11.72714,21.21636103515625,11.82003,21.74216103515625,11.92453C21.74216103515625,11.91679,22.13166103515625,12.00968,22.91076103515625,12.20319ZM18.09616103515625,18.9724L17.06782103515625,22.4557L18.773961035156248,24L21.11116103515625,23.465899999999998L21.788961035156248,19.5414C21.298161035156248,19.402,20.81896103515625,19.2511,20.32816103515625,19.1582C19.60366103515625,19.076900000000002,18.86746103515625,18.9724,18.09616103515625,18.9724ZM27.49166103515625,14.14223C26.74376103515625,13.69327,25.99586103515625,13.306239999999999,25.24796103515625,12.98113C24.52346103515625,12.69086,23.74046103515625,12.40058,22.95756103515625,12.20319L22.95756103515625,12.40058L21.69546103515625,19.5646C21.89416103515625,19.6575,22.139561035156248,19.7039,22.32646103515625,19.8084C22.77836103515625,20.0019,23.101661035156248,20.1645,23.29646103515625,20.2961C23.78726103515625,20.586399999999998,24.51176103515625,20.4354,24.80396103515625,19.95940000
0000002L27.77216103515625,15.605229999999999C28.16946103515625,15.05951,28.02926103515625,14.43251,27.49166103515625,14.14223Z", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + }, + { + "type": "element", + "name": "g", + "attributes": {}, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M35.785,3.8624638671875L50.233000000000004,3.8624638671875L50.233000000000004,5.9204638671875L35.785,5.9204638671875L35.785,3.8624638671875ZM43.156,11.5904638671875Q42.106,13.4594638671875,40.7515,15.4754638671875Q39.397,17.4914638671875,38.599000000000004,18.3104638671875L46.978,17.5334638671875Q45.382,15.0974638671875,44.479,13.8794638671875L46.306,12.7034638671875Q47.397999999999996,14.1104638671875,48.9835,16.3784638671875Q50.569,18.6464638671875,51.492999999999995,20.1794638671875L49.54,21.6494638671875Q49.057,20.8094638671875,48.238,19.5074638671875Q46.243,19.6334638671875,42.82,19.9274638671875Q39.397,20.2214638671875,37.024,20.4734638671875L36.184,20.5784638671875L35.47,20.6834638671875L34.84,18.5624638671875Q35.281,18.4154638671875,35.512,18.2579638671875Q35.743,18.1004638671875,35.9005,17.963963867187502Q36.058,17.8274638671875,36.121,17.7644638671875Q37.087,16.840463867187502,38.305,15.1079638671875Q39.522999999999996,13.3754638671875,40.531,11.5904638671875L34,11.5904638671875L34,9.5114638671875L52.018,9.5114638671875L52.018,11.5904638671875L43.156,11.5904638671875ZM62.203,10.9814638671875L62.203,12.7244638671875L60.25,12.7244638671875L60.25,2.5814638671875L62.203,2.6654638671875L62.203,10.4144638671875Q63.19,8.6294638671875,64.051,6.4139638671875Q64.912,4.1984638671875,65.28999999999999,2.3504638671875L67.348,2.8334628671875Q67.15899999999999,3.7784638671875,66.80199999999999,4.9754638671875L72.619,4.9754638671875L72.619,6.9704638671875L66.13,6.9704638671875Q65.143,9.7004638671875,63.778,12.0524638671875L62.203,10.9814638671875ZM56.113,3.3794638671875L58.045,3.4634638671875L58.045,12.1784638671875L56.113,12.
1784638671875L56.113,3.3794638671875ZM67.495,7.3064638671875Q68.251,7.8944638671875,69.469,9.1229638671875Q70.687,10.3514638671875,71.40100000000001,11.2334638671875L69.84700000000001,12.7454638671875Q69.238,11.9684638671875,68.083,10.7714638671875Q66.928,9.5744638671875,66.025,8.7134638671875L67.495,7.3064638671875ZM70.834,13.3754638671875L70.834,18.9194638671875L73.06,18.9194638671875L73.06,20.8094638671875L54.307,20.8094638671875L54.307,18.9194638671875L56.491,18.9194638671875L56.491,13.3754638671875L70.834,13.3754638671875ZM60.733000000000004,15.2444638671875L58.465,15.2444638671875L58.465,18.9194638671875L60.733000000000004,18.9194638671875L60.733000000000004,15.2444638671875ZM62.581,18.9194638671875L64.765,18.9194638671875L64.765,15.2444638671875L62.581,15.2444638671875L62.581,18.9194638671875ZM66.592,18.9194638671875L68.881,18.9194638671875L68.881,15.2444638671875L66.592,15.2444638671875L66.592,18.9194638671875ZM80.578,11.0444638671875L80.893,12.4514638671875L79.48599999999999,13.0814638671875L79.48599999999999,19.0874638671875Q79.48599999999999,20.0114638671875,79.2655,20.4629638671875Q79.045,20.9144638671875,78.52000000000001,21.1034638671875Q77.995,21.2924638671875,76.90299999999999,21.3974638671875L76.021,21.4814638671875L75.43299999999999,19.4864638671875L76.462,19.4024638671875Q76.987,19.3604638671875,77.197,19.2974638671875Q77.407,19.2344638671875,77.4805,19.0559638671875Q77.554,18.8774638671875,77.554,18.4364638671875L77.554,13.9004638671875Q76.189,14.4464638671875,75.202,14.7824638671875L74.74000000000001,12.7244638671875Q75.916,12.3464638671875,77.554,11.6744638671875L77.554,8.1464638671875L75.34899999999999,8.1464638671875L75.34899999999999,6.1094638671875L77.554,6.1094638671875L77.554,2.4974628671875L79.48599999999999,2.5814638671875L79.48599999999999,6.1094638671875L81.03999999999999,6.1094638671875L81.03999999999999,8.1464638671875L79.48599999999999,8.1464638671875L79.48599999999999,10.8344638671875L80.431,10.3934638671875L80.578,11.044463867187
5ZM83.56,6.6764638671875L83.56,9.0074638671875L81.565,9.0074638671875L81.565,4.7444638671875L86.24799999999999,4.7444638671875Q85.84899999999999,3.3794638671875,85.618,2.7494638671875L87.655,2.4974628671875Q87.991,3.2744638671875,88.432,4.7444638671875L93.094,4.7444638671875L93.094,9.0074638671875L91.162,9.0074638671875L91.162,6.6764638671875L83.56,6.6764638671875ZM86.731,9.3434638671875Q85.807,10.2674638671875,84.7465,11.1284638671875Q83.686,11.9894638671875,82.15299999999999,13.1234638671875L81.082,11.5064638671875Q83.455,9.9524638671875,85.408,7.9154638671875L86.731,9.3434638671875ZM88.852,7.9154638671875Q89.755,8.5244638671875,91.3615,9.731963867187499Q92.968,10.9394638671875,93.703,11.5694638671875L92.632,13.3334638671875Q91.771,12.5354638671875,90.217,11.3384638671875Q88.663,10.1414638671875,87.718,9.5114638671875L88.852,7.9154638671875ZM92.107,15.2444638671875L88.285,15.2444638671875L88.285,18.7094638671875L93.577,18.7094638671875L93.577,20.5994638671875L80.935,20.5994638671875L80.935,18.7094638671875L86.164,18.7094638671875L86.164,15.2444638671875L82.3,15.2444638671875L82.3,13.3334638671875L92.107,13.3334638671875L92.107,15.2444638671875Z", + "fill": "#FF6A00", + "fill-opacity": "1" + }, + "children": [] + } + ] + } + ] + } + ] + }, + "name": "AliyunIconBig" } diff --git a/web/app/components/billing/pricing/footer.tsx b/web/app/components/billing/pricing/footer.tsx index 4e3cdfee3d..fd713eb3da 100644 --- a/web/app/components/billing/pricing/footer.tsx +++ b/web/app/components/billing/pricing/footer.tsx @@ -2,19 +2,29 @@ import React from 'react' import Link from 'next/link' import { useTranslation } from 'react-i18next' import { RiArrowRightUpLine } from '@remixicon/react' +import { type Category, CategoryEnum } from '.' +import cn from '@/utils/classnames' type FooterProps = { pricingPageURL: string + currentCategory: Category } const Footer = ({ pricingPageURL, + currentCategory, }: FooterProps) => { const { t } = useTranslation() return (
-
+
+ {currentCategory === CategoryEnum.CLOUD && ( +
+ {t('billing.plansCommon.taxTip')} + {t('billing.plansCommon.taxTipSecond')} +
+ )} void @@ -25,7 +30,7 @@ const Pricing: FC = ({ const { plan } = useProviderContext() const { isCurrentWorkspaceManager } = useAppContext() const [planRange, setPlanRange] = React.useState(PlanRange.monthly) - const [currentCategory, setCurrentCategory] = useState('cloud') + const [currentCategory, setCurrentCategory] = useState(CategoryEnum.CLOUD) const canPay = isCurrentWorkspaceManager useKeyPress(['esc'], onCancel) @@ -57,7 +62,7 @@ const Pricing: FC = ({ planRange={planRange} canPay={canPay} /> -
+
diff --git a/web/app/components/datasets/extra-info/service-api/index.tsx b/web/app/components/datasets/extra-info/service-api/index.tsx index b1843682ee..af7ce946ad 100644 --- a/web/app/components/datasets/extra-info/service-api/index.tsx +++ b/web/app/components/datasets/extra-info/service-api/index.tsx @@ -52,7 +52,7 @@ const ServiceApi = ({ />
- + : undefined} />
) diff --git a/web/app/components/plugins/hooks.ts b/web/app/components/plugins/hooks.ts index 0af7c1a170..f22b2c4d69 100644 --- a/web/app/components/plugins/hooks.ts +++ b/web/app/components/plugins/hooks.ts @@ -1,3 +1,4 @@ +import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import type { TFunction } from 'i18next' import { @@ -14,23 +15,29 @@ export const useTags = (translateFromOut?: TFunction) => { const { t: translation } = useTranslation() const t = translateFromOut || translation - const tags = tagKeys.map((tag) => { - return { - name: tag, - label: t(`pluginTags.tags.${tag}`), + const tags = useMemo(() => { + return tagKeys.map((tag) => { + return { + name: tag, + label: t(`pluginTags.tags.${tag}`), + } + }) + }, [t]) + + const tagsMap = useMemo(() => { + return tags.reduce((acc, tag) => { + acc[tag.name] = tag + return acc + }, {} as Record) + }, [tags]) + + const getTagLabel = useMemo(() => { + return (name: string) => { + if (!tagsMap[name]) + return name + return tagsMap[name].label } - }) - - const tagsMap = tags.reduce((acc, tag) => { - acc[tag.name] = tag - return acc - }, {} as Record) - - const getTagLabel = (name: string) => { - if (!tagsMap[name]) - return name - return tagsMap[name].label - } + }, [tagsMap]) return { tags, @@ -48,23 +55,27 @@ export const useCategories = (translateFromOut?: TFunction) => { const { t: translation } = useTranslation() const t = translateFromOut || translation - const categories = categoryKeys.map((category) => { - if (category === 'agent-strategy') { - return { - name: 'agent-strategy', - label: t('plugin.category.agents'), + const categories = useMemo(() => { + return categoryKeys.map((category) => { + if (category === 'agent-strategy') { + return { + name: 'agent-strategy', + label: t('plugin.category.agents'), + } } - } - return { - name: category, - label: t(`plugin.category.${category}s`), - } - }) + return { + name: category, + label: t(`plugin.category.${category}s`), + } + }) + 
}, [t]) - const categoriesMap = categories.reduce((acc, category) => { - acc[category.name] = category - return acc - }, {} as Record) + const categoriesMap = useMemo(() => { + return categories.reduce((acc, category) => { + acc[category.name] = category + return acc + }, {} as Record) + }, [categories]) return { categories, @@ -76,23 +87,27 @@ export const useSingleCategories = (translateFromOut?: TFunction) => { const { t: translation } = useTranslation() const t = translateFromOut || translation - const categories = categoryKeys.map((category) => { - if (category === 'agent-strategy') { - return { - name: 'agent-strategy', - label: t('plugin.categorySingle.agent'), + const categories = useMemo(() => { + return categoryKeys.map((category) => { + if (category === 'agent-strategy') { + return { + name: 'agent-strategy', + label: t('plugin.categorySingle.agent'), + } } - } - return { - name: category, - label: t(`plugin.categorySingle.${category}`), - } - }) + return { + name: category, + label: t(`plugin.categorySingle.${category}`), + } + }) + }, [t]) - const categoriesMap = categories.reduce((acc, category) => { - acc[category.name] = category - return acc - }, {} as Record) + const categoriesMap = useMemo(() => { + return categories.reduce((acc, category) => { + acc[category.name] = category + return acc + }, {} as Record) + }, [categories]) return { categories, diff --git a/web/app/components/tools/mcp/modal.tsx b/web/app/components/tools/mcp/modal.tsx index 1a12b3b3e9..1d888c57e8 100644 --- a/web/app/components/tools/mcp/modal.tsx +++ b/web/app/components/tools/mcp/modal.tsx @@ -3,6 +3,7 @@ import React, { useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { getDomain } from 'tldts' import { RiCloseLine, RiEditLine } from '@remixicon/react' +import { Mcp } from '@/app/components/base/icons/src/vender/other' import AppIconPicker from '@/app/components/base/app-icon-picker' import type { AppIconSelection } from 
'@/app/components/base/app-icon-picker' import AppIcon from '@/app/components/base/app-icon' @@ -17,6 +18,7 @@ import Toast from '@/app/components/base/toast' import { uploadRemoteFileInfo } from '@/service/common' import cn from '@/utils/classnames' import { useHover } from 'ahooks' +import { shouldUseMcpIconForAppIcon } from '@/utils/mcp' export type DuplicateAppModalProps = { data?: ToolWithProvider @@ -35,7 +37,7 @@ export type DuplicateAppModalProps = { onHide: () => void } -const DEFAULT_ICON = { type: 'emoji', icon: '🧿', background: '#EFF1F5' } +const DEFAULT_ICON = { type: 'emoji', icon: '🔗', background: '#6366F1' } const extractFileId = (url: string) => { const match = url.match(/files\/(.+?)\/file-preview/) return match ? match[1] : null @@ -208,6 +210,7 @@ const MCPModal = ({ icon={appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId} background={appIcon.type === 'emoji' ? appIcon.background : undefined} imageUrl={appIcon.type === 'image' ? appIcon.url : undefined} + innerIcon={shouldUseMcpIconForAppIcon(appIcon.type, appIcon.type === 'emoji' ? appIcon.icon : '') ? 
: undefined} size='xxl' className='relative cursor-pointer rounded-2xl' coverElement={ diff --git a/web/app/components/tools/provider-list.tsx b/web/app/components/tools/provider-list.tsx index 08a4aa0b5d..1679b4469b 100644 --- a/web/app/components/tools/provider-list.tsx +++ b/web/app/components/tools/provider-list.tsx @@ -21,6 +21,7 @@ import { useCheckInstalled, useInvalidateInstalledPluginList } from '@/service/u import { useGlobalPublicStore } from '@/context/global-public-context' import { ToolTypeEnum } from '../workflow/block-selector/types' import { useMarketplace } from './marketplace/hooks' +import { useTags } from '@/app/components/plugins/hooks' const getToolType = (type: string) => { switch (type) { @@ -40,6 +41,7 @@ const ProviderList = () => { // const searchParams = useSearchParams() // searchParams.get('category') === 'workflow' const { t } = useTranslation() + const { getTagLabel } = useTags() const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures) const containerRef = useRef(null) @@ -180,7 +182,7 @@ const ProviderList = () => { } as any} footer={ getTagLabel(label)) || []} /> } /> diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index 75c4d51390..b289cafefd 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -412,10 +412,10 @@ export const Workflow: FC = memo(({ nodesFocusable={!nodesReadOnly} edgesFocusable={!nodesReadOnly} panOnScroll={false} - panOnDrag={controlMode === ControlMode.Hand && !workflowReadOnly} - zoomOnPinch={!workflowReadOnly} - zoomOnScroll={!workflowReadOnly} - zoomOnDoubleClick={!workflowReadOnly} + panOnDrag={controlMode === ControlMode.Hand} + zoomOnPinch={true} + zoomOnScroll={true} + zoomOnDoubleClick={true} isValidConnection={isValidConnection} selectionKeyCode={null} selectionMode={SelectionMode.Partial} diff --git a/web/i18n/de-DE/billing.ts b/web/i18n/de-DE/billing.ts index 98d4488fab..fc45f3889c 100644 --- 
a/web/i18n/de-DE/billing.ts +++ b/web/i18n/de-DE/billing.ts @@ -94,6 +94,8 @@ const translation = { teamMember_one: '{{count,number}} Teammitglied', documentsRequestQuotaTooltip: 'Gibt die Gesamtzahl der Aktionen an, die ein Arbeitsbereich pro Minute innerhalb der Wissensbasis ausführen kann, einschließlich der Erstellung, Löschung, Aktualisierung von Datensätzen, des Hochladens von Dokumenten, von Änderungen, der Archivierung und von Abfragen in der Wissensbasis. Diese Kennzahl wird verwendet, um die Leistung von Anfragen an die Wissensbasis zu bewerten. Wenn ein Sandbox-Nutzer beispielsweise in einer Minute 10 aufeinanderfolgende Testdurchläufe durchführt, wird sein Arbeitsbereich für die nächste Minute vorübergehend daran gehindert, die folgenden Aktionen auszuführen: Erstellung, Löschung, Aktualisierung von Datensätzen sowie das Hochladen oder Ändern von Dokumenten.', startBuilding: 'Beginnen Sie mit der Entwicklung', + taxTipSecond: 'Wenn in Ihrer Region keine relevanten Steuervorschriften gelten, wird an der Kasse keine Steuer angezeigt und Ihnen werden während der gesamten Abonnementlaufzeit keine zusätzlichen Gebühren berechnet.', + taxTip: 'Alle Abonnementspreise (monatlich/jährlich) verstehen sich zuzüglich der geltenden Steuern (z. B. 
MwSt., Umsatzsteuer).', }, plans: { sandbox: { diff --git a/web/i18n/en-US/billing.ts b/web/i18n/en-US/billing.ts index d85dccf6ac..9d68b8e934 100644 --- a/web/i18n/en-US/billing.ts +++ b/web/i18n/en-US/billing.ts @@ -37,6 +37,8 @@ const translation = { save: 'Save ', free: 'Free', annualBilling: 'Bill Annually Save {{percent}}%', + taxTip: 'All subscription prices (monthly/annual) exclude applicable taxes (e.g., VAT, sales tax).', + taxTipSecond: 'If your region has no applicable tax requirements, no tax will appear in your checkout, and you won’t be charged any additional fees for the entire subscription term.', comparePlanAndFeatures: 'Compare plans & features', priceTip: 'per workspace/', currentPlan: 'Current Plan', diff --git a/web/i18n/es-ES/billing.ts b/web/i18n/es-ES/billing.ts index c5d4ef95b9..a8180e2d07 100644 --- a/web/i18n/es-ES/billing.ts +++ b/web/i18n/es-ES/billing.ts @@ -94,6 +94,8 @@ const translation = { apiRateLimitTooltip: 'El límite de tasa de la API se aplica a todas las solicitudes realizadas a través de la API de Dify, incluidos la generación de texto, las conversaciones de chat, las ejecuciones de flujo de trabajo y el procesamiento de documentos.', documentsRequestQuotaTooltip: 'Especifica el número total de acciones que un espacio de trabajo puede realizar por minuto dentro de la base de conocimientos, incluyendo la creación, eliminación, actualización de conjuntos de datos, carga de documentos, modificaciones, archivo y consultas a la base de conocimientos. Esta métrica se utiliza para evaluar el rendimiento de las solicitudes a la base de conocimientos. 
Por ejemplo, si un usuario de Sandbox realiza 10 pruebas consecutivas en un minuto, su espacio de trabajo será temporalmente restringido de realizar las siguientes acciones durante el siguiente minuto: creación de conjuntos de datos, eliminación, actualizaciones y carga o modificaciones de documentos.', startBuilding: 'Empezar a construir', + taxTip: 'Todos los precios de suscripción (mensuales/anuales) excluyen los impuestos aplicables (por ejemplo, IVA, impuesto sobre ventas).', + taxTipSecond: 'Si su región no tiene requisitos fiscales aplicables, no se mostrará ningún impuesto en su pago y no se le cobrará ninguna tarifa adicional durante todo el período de suscripción.', }, plans: { sandbox: { diff --git a/web/i18n/fa-IR/billing.ts b/web/i18n/fa-IR/billing.ts index 5634692dc2..3749036f3c 100644 --- a/web/i18n/fa-IR/billing.ts +++ b/web/i18n/fa-IR/billing.ts @@ -94,6 +94,8 @@ const translation = { apiRateLimitTooltip: 'محدودیت نرخ API برای همه درخواست‌های انجام شده از طریق API Dify اعمال می‌شود، از جمله تولید متن، محاوره‌های چت، اجرای گردش‌های کار و پردازش اسناد.', documentsRequestQuotaTooltip: 'تعیین می‌کند که تعداد کلی اقداماتی که یک فضای کاری می‌تواند در هر دقیقه در داخل پایگاه دانش انجام دهد، شامل ایجاد مجموعه داده، حذف، به‌روزرسانی، بارگذاری مستندات، تغییرات، بایگانی و پرسش از پایگاه دانش است. این معیار برای ارزیابی عملکرد درخواست‌های پایگاه دانش استفاده می‌شود. 
به عنوان مثال، اگر یک کاربر Sandbox در طی یک دقیقه 10 آزمایش متوالی انجام دهد، فضای کاری او به طور موقت از انجام اقدامات زیر در دقیقه بعدی محدود خواهد شد: ایجاد مجموعه داده، حذف، به‌روزرسانی و بارگذاری یا تغییر مستندات.', startBuilding: 'شروع به ساخت کنید', + taxTip: 'تمام قیمت‌های اشتراک (ماهانه/سالانه) شامل مالیات‌های مربوطه (مثلاً مالیات بر ارزش افزوده، مالیات فروش) نمی‌شوند.', + taxTipSecond: 'اگر منطقه شما هیچ الزامات مالیاتی قابل اجرا نداشته باشد، هیچ مالیاتی در هنگام پرداخت نشان داده نمی‌شود و برای کل مدت اشتراک هیچ هزینه اضافی از شما دریافت نخواهد شد.', }, plans: { sandbox: { diff --git a/web/i18n/fr-FR/billing.ts b/web/i18n/fr-FR/billing.ts index 117d1c6654..a41eed7e23 100644 --- a/web/i18n/fr-FR/billing.ts +++ b/web/i18n/fr-FR/billing.ts @@ -94,6 +94,8 @@ const translation = { documents: '{{count,number}} Documents de connaissance', documentsRequestQuotaTooltip: 'Spécifie le nombre total d\'actions qu\'un espace de travail peut effectuer par minute dans la base de connaissances, y compris la création, la suppression, les mises à jour de jeux de données, le téléchargement de documents, les modifications, l\'archivage et les requêtes de la base de connaissances. Ce paramètre est utilisé pour évaluer les performances des requêtes de la base de connaissances. 
Par exemple, si un utilisateur de Sandbox effectue 10 tests de validité consécutifs en une minute, son espace de travail sera temporairement restreint dans l\'exécution des actions suivantes pendant la minute suivante : création, suppression, mises à jour de jeux de données, et téléchargements ou modifications de documents.', startBuilding: 'Commencez à construire', + taxTip: 'Tous les prix des abonnements (mensuels/annuels) s\'entendent hors taxes applicables (par exemple, TVA, taxe de vente).', + taxTipSecond: 'Si votre région n\'a pas de exigences fiscales applicables, aucune taxe n\'apparaîtra lors de votre paiement et vous ne serez pas facturé de frais supplémentaires pendant toute la durée de l\'abonnement.', }, plans: { sandbox: { diff --git a/web/i18n/hi-IN/billing.ts b/web/i18n/hi-IN/billing.ts index 749ab804ab..fbc6dffc7c 100644 --- a/web/i18n/hi-IN/billing.ts +++ b/web/i18n/hi-IN/billing.ts @@ -102,6 +102,8 @@ const translation = { teamMember_one: '{{count,number}} टीम सदस्य', documentsRequestQuotaTooltip: 'यह ज्ञान आधार में एक कार्यक्षेत्र द्वारा प्रति मिनट किए जा सकने वाले कुल कार्यों की संख्या को निर्दिष्ट करता है, जिसमें डेटासेट बनाना, हटाना, अपडेट करना, दस्तावेज़ अपलोड करना, संशोधन करना, संग्रहित करना और ज्ञान आधार अनुरोध शामिल हैं। इस मीट्रिक का उपयोग ज्ञान आधार अनुरोधों के प्रदर्शन का मूल्यांकन करने के लिए किया जाता है। उदाहरण के लिए, यदि एक सैंडबॉक्स उपयोगकर्ता एक मिनट के भीतर 10 लगातार हिट परीक्षण करता है, तो उनके कार्यक्षेत्र को अगले मिनट के लिए निम्नलिखित कार्यों को करने से अस्थायी रूप से प्रतिबंधित किया जाएगा: डेटासेट बनाना, हटाना, अपडेट करना और दस्तावेज़ अपलोड या संशोधन करना।', startBuilding: 'बनाना शुरू करें', + taxTip: 'सभी सदस्यता मूल्य (मासिक/वार्षिक) लागू करों (जैसे, VAT, बिक्री कर) को शामिल नहीं करते हैं।', + taxTipSecond: 'यदि आपके क्षेत्र में कोई लागू कर आवश्यकताएँ नहीं हैं, तो आपकी चेकआउट में कोई कर नहीं दिखाई देगा, और पूरे सदस्यता अवधि के लिए आपसे कोई अतिरिक्त शुल्क नहीं लिया जाएगा।', }, plans: { sandbox: { diff --git 
a/web/i18n/id-ID/billing.ts b/web/i18n/id-ID/billing.ts index 11419c3b16..c6c718d15b 100644 --- a/web/i18n/id-ID/billing.ts +++ b/web/i18n/id-ID/billing.ts @@ -87,6 +87,8 @@ const translation = { modelProviders: 'Mendukung OpenAI/Anthropic/Llama2/Azure OpenAI/Hugging Face/Replite', member: 'Anggota', startBuilding: 'Mulai Membangun', + taxTip: 'Semua harga langganan (bulanan/tahunan) belum termasuk pajak yang berlaku (misalnya, PPN, pajak penjualan).', + taxTipSecond: 'Jika wilayah Anda tidak memiliki persyaratan pajak yang berlaku, tidak akan ada pajak yang muncul saat checkout, dan Anda tidak akan dikenakan biaya tambahan apa pun selama masa langganan.', }, plans: { sandbox: { diff --git a/web/i18n/it-IT/billing.ts b/web/i18n/it-IT/billing.ts index f89502ee5b..ef6b1943e3 100644 --- a/web/i18n/it-IT/billing.ts +++ b/web/i18n/it-IT/billing.ts @@ -102,6 +102,8 @@ const translation = { annualBilling: 'Fatturazione annuale', documentsRequestQuotaTooltip: 'Specifica il numero totale di azioni che un\'area di lavoro può eseguire al minuto all\'interno della base di conoscenza, compresi la creazione, l\'eliminazione, gli aggiornamenti dei dataset, il caricamento di documenti, le modifiche, l\'archiviazione e le query sulla base di conoscenza. Questa metrica viene utilizzata per valutare le prestazioni delle richieste alla base di conoscenza. 
Ad esempio, se un utente di Sandbox esegue 10 test consecutivi in un minuto, la sua area di lavoro sarà temporaneamente limitata dall\'eseguire le seguenti azioni per il minuto successivo: creazione, eliminazione, aggiornamenti dei dataset e caricamento o modifica di documenti.', startBuilding: 'Inizia a costruire', + taxTip: 'Tutti i prezzi degli abbonamenti (mensili/annuali) non includono le tasse applicabili (ad esempio, IVA, imposta sulle vendite).', + taxTipSecond: 'Se nella tua regione non ci sono requisiti fiscali applicabili, nessuna tassa apparirà al momento del pagamento e non ti verranno addebitate spese aggiuntive per l\'intera durata dell\'abbonamento.', }, plans: { sandbox: { diff --git a/web/i18n/ja-JP/billing.ts b/web/i18n/ja-JP/billing.ts index 428ce29ed2..2a5f74c15f 100644 --- a/web/i18n/ja-JP/billing.ts +++ b/web/i18n/ja-JP/billing.ts @@ -36,6 +36,8 @@ const translation = { save: '節約 ', free: '無料', annualBilling: '年次請求', + taxTip: 'すべてのサブスクリプション料金(月額/年額)は、適用される税金(例:消費税、付加価値税)を含みません。', + taxTipSecond: 'お客様の地域に適用税がない場合、チェックアウト時に税金は表示されず、サブスクリプション期間中に追加料金が請求されることもありません。', comparePlanAndFeatures: 'プランと機能を比較する', priceTip: 'ワークスペース/', currentPlan: '現在のプラン', diff --git a/web/i18n/ko-KR/billing.ts b/web/i18n/ko-KR/billing.ts index ff0dd189e4..c5f081d41b 100644 --- a/web/i18n/ko-KR/billing.ts +++ b/web/i18n/ko-KR/billing.ts @@ -103,6 +103,8 @@ const translation = { documentsRequestQuotaTooltip: '지식 기반 내에서 작업 공간이 분당 수행할 수 있는 총 작업 수를 지정합니다. 여기에는 데이터 세트 생성, 삭제, 업데이트, 문서 업로드, 수정, 보관 및 지식 기반 쿼리가 포함됩니다. 이 지표는 지식 기반 요청의 성능을 평가하는 데 사용됩니다. 
예를 들어, 샌드박스 사용자가 1 분 이내에 10 회의 연속 히트 테스트를 수행하면, 해당 작업 공간은 다음 1 분 동안 데이터 세트 생성, 삭제, 업데이트 및 문서 업로드 또는 수정과 같은 작업을 수행하는 것이 일시적으로 제한됩니다.', startBuilding: '구축 시작', + taxTip: '모든 구독 요금(월간/연간)에는 해당 세금(예: 부가가치세, 판매세)이 포함되어 있지 않습니다.', + taxTipSecond: '귀하의 지역에 적용 가능한 세금 요구 사항이 없는 경우, 결제 시 세금이 표시되지 않으며 전체 구독 기간 동안 추가 요금이 부과되지 않습니다.', }, plans: { sandbox: { diff --git a/web/i18n/pl-PL/billing.ts b/web/i18n/pl-PL/billing.ts index 3bf0867877..cf0859468b 100644 --- a/web/i18n/pl-PL/billing.ts +++ b/web/i18n/pl-PL/billing.ts @@ -101,6 +101,8 @@ const translation = { documentsRequestQuota: '{{count,number}}/min Limit wiedzy na żądanie', documentsRequestQuotaTooltip: 'Określa całkowitą liczbę działań, jakie przestrzeń robocza może wykonać na minutę w ramach bazy wiedzy, w tym tworzenie zbiorów danych, usuwanie, aktualizacje, przesyłanie dokumentów, modyfikacje, archiwizowanie i zapytania do bazy wiedzy. Ta metryka jest używana do oceny wydajności zapytań do bazy wiedzy. Na przykład, jeśli użytkownik Sandbox wykona 10 kolejnych testów w ciągu jednej minuty, jego przestrzeń robocza zostanie tymczasowo ograniczona w wykonywaniu następujących działań przez następną minutę: tworzenie zbiorów danych, usuwanie, aktualizacje oraz przesyłanie lub modyfikacje dokumentów.', startBuilding: 'Zacznij budować', + taxTip: 'Wszystkie ceny subskrypcji (miesięczne/roczne) nie obejmują obowiązujących podatków (np. 
VAT, podatek od sprzedaży).', + taxTipSecond: 'Jeśli w Twoim regionie nie ma obowiązujących przepisów podatkowych, podatek nie pojawi się podczas realizacji zamówienia i nie zostaną naliczone żadne dodatkowe opłaty przez cały okres subskrypcji.', }, plans: { sandbox: { diff --git a/web/i18n/pt-BR/billing.ts b/web/i18n/pt-BR/billing.ts index 91ccaa7794..e4ca0a064a 100644 --- a/web/i18n/pt-BR/billing.ts +++ b/web/i18n/pt-BR/billing.ts @@ -94,6 +94,8 @@ const translation = { apiRateLimitTooltip: 'O limite da taxa da API se aplica a todas as solicitações feitas através da API Dify, incluindo geração de texto, conversas de chat, execuções de fluxo de trabalho e processamento de documentos.', documentsRequestQuotaTooltip: 'Especifica o número total de ações que um espaço de trabalho pode realizar por minuto dentro da base de conhecimento, incluindo criação, exclusão, atualizações de conjuntos de dados, uploads de documentos, modificações, arquivamento e consultas à base de conhecimento. Esse métrica é utilizada para avaliar o desempenho das solicitações à base de conhecimento. 
Por exemplo, se um usuário do Sandbox realizar 10 testes de impacto consecutivos dentro de um minuto, seu espaço de trabalho ficará temporariamente restrito de realizar as seguintes ações no minuto seguinte: criação, exclusão, atualizações de conjuntos de dados e uploads ou modificações de documentos.', startBuilding: 'Comece a construir', + taxTip: 'Todos os preços de assinatura (mensal/anual) não incluem os impostos aplicáveis (por exemplo, IVA, imposto sobre vendas).', + taxTipSecond: 'Se a sua região não tiver requisitos fiscais aplicáveis, nenhum imposto aparecerá no seu checkout e você não será cobrado por taxas adicionais durante todo o período da assinatura.', }, plans: { sandbox: { diff --git a/web/i18n/ro-RO/billing.ts b/web/i18n/ro-RO/billing.ts index 550ff3e677..3f5577dc32 100644 --- a/web/i18n/ro-RO/billing.ts +++ b/web/i18n/ro-RO/billing.ts @@ -94,6 +94,8 @@ const translation = { documentsRequestQuotaTooltip: 'Specificați numărul total de acțiuni pe care un spațiu de lucru le poate efectua pe minut în cadrul bazei de cunoștințe, inclusiv crearea, ștergerea, actualizările setului de date, încărcările de documente, modificările, arhivarea și interogările bazei de cunoștințe. Acest metric este utilizat pentru a evalua performanța cererilor din baza de cunoștințe. 
De exemplu, dacă un utilizator Sandbox efectuează 10 teste consecutive de hituri într-un minut, spațiul său de lucru va fi restricționat temporar de la efectuarea următoarelor acțiuni pentru minutul următor: crearea setului de date, ștergerea, actualizările și încărcările sau modificările documentelor.', apiRateLimitTooltip: 'Limita de rată API se aplică tuturor cererilor efectuate prin API-ul Dify, inclusiv generarea de texte, conversațiile de chat, execuțiile fluxului de lucru și procesarea documentelor.', startBuilding: 'Începeți să construiți', + taxTip: 'Toate prețurile abonamentelor (lunare/anuale) nu includ taxele aplicabile (de exemplu, TVA, taxa pe vânzări).', + taxTipSecond: 'Dacă regiunea dumneavoastră nu are cerințe fiscale aplicabile, niciun impozit nu va apărea la finalizarea comenzii și nu vi se vor percepe taxe suplimentare pe întreaga durată a abonamentului.', }, plans: { sandbox: { diff --git a/web/i18n/ru-RU/billing.ts b/web/i18n/ru-RU/billing.ts index 27f5c71685..7017f90cc2 100644 --- a/web/i18n/ru-RU/billing.ts +++ b/web/i18n/ru-RU/billing.ts @@ -94,6 +94,8 @@ const translation = { priceTip: 'по рабочему месту/', documentsTooltip: 'Квота на количество документов, импортируемых из источника знаний.', startBuilding: 'Начать строительство', + taxTip: 'Все цены на подписку (ежемесячную/годовую) не включают применимые налоги (например, НДС, налог с продаж).', + taxTipSecond: 'Если в вашем регионе нет применимых налоговых требований, налоги не будут отображаться при оформлении заказа, и с вас не будут взиматься дополнительные сборы за весь срок подписки.', }, plans: { sandbox: { diff --git a/web/i18n/sl-SI/billing.ts b/web/i18n/sl-SI/billing.ts index 4481100dd8..fb9d9ec435 100644 --- a/web/i18n/sl-SI/billing.ts +++ b/web/i18n/sl-SI/billing.ts @@ -94,6 +94,8 @@ const translation = { getStarted: 'Začnite', documentsRequestQuotaTooltip: 'Določa skupno število dejanj, ki jih lahko delovno mesto opravi na minuto znotraj znanja baze, vključno s kreiranjem, 
brisanjem, posodobitvami, nalaganjem dokumentov, spremembami, arhiviranjem in poizvedbami po znanju bazi. Ta meritev se uporablja za ocenjevanje uspešnosti poizvedb v bazi znanja. Na primer, če uporabnik Sandbox izvede 10 zaporednih testov udarca v eni minuti, bo njegovo delovno mesto začasno omejeno pri izvajanju naslednjih dejanj v naslednji minuti: kreiranje podatkovnih nizov, brisanje, posodobitve in nalaganje ali spremembe dokumentov.', startBuilding: 'Začnite graditi', + taxTip: 'Vse cene naročnin (mesečne/letne) ne vključujejo veljavnih davkov (npr. DDV, davek na promet).', + taxTipSecond: 'Če vaša regija nima veljavnih davčnih zahtev, se v vaši košarici ne bo prikazal noben davek in za celotno obdobje naročnine vam ne bodo zaračunani nobeni dodatni stroški.', }, plans: { sandbox: { diff --git a/web/i18n/th-TH/billing.ts b/web/i18n/th-TH/billing.ts index 55a01449eb..461e4a8240 100644 --- a/web/i18n/th-TH/billing.ts +++ b/web/i18n/th-TH/billing.ts @@ -94,6 +94,8 @@ const translation = { annualBilling: 'การเรียกเก็บเงินประจำปี', documentsRequestQuotaTooltip: 'ระบุจำนวนรวมของการกระทำที่เวิร์กสเปซสามารถดำเนินการต่อหนึ่งนาทีภายในฐานความรู้ รวมถึงการสร้างชุดข้อมูล การลบ การอัปเดต การอัปโหลดเอกสาร การปรับเปลี่ยน การเก็บถาวร และการสอบถามฐานความรู้ เมตริกนี้ถูกใช้ในการประเมินประสิทธิภาพของคำขอฐานความรู้ ตัวอย่างเช่น หากผู้ใช้ Sandbox ทำการทดสอบการตี 10 ครั้งต่อเนื่องภายในหนึ่งนาที เวิร์กสเปซของพวกเขาจะถูกจำกัดชั่วคราวในการดำเนินการต่อไปนี้ในนาทีถัดไป: การสร้างชุดข้อมูล การลบ การอัปเดต หรือการอัปโหลดหรือปรับเปลี่ยนเอกสาร.', startBuilding: 'เริ่มสร้าง', + taxTip: 'ราคาการสมัครสมาชิกทั้งหมด (รายเดือน/รายปี) ไม่รวมภาษีที่ใช้บังคับ (เช่น ภาษีมูลค่าเพิ่ม, ภาษีการขาย)', + taxTipSecond: 'หากภูมิภาคของคุณไม่มีข้อกำหนดเกี่ยวกับภาษีที่ใช้ได้ จะไม่มีการคิดภาษีในขั้นตอนการชำระเงินของคุณ และคุณจะไม่ถูกเรียกเก็บค่าธรรมเนียมเพิ่มเติมใด ๆ ตลอดระยะเวลาสมาชิกทั้งหมด', }, plans: { sandbox: { diff --git a/web/i18n/tr-TR/billing.ts b/web/i18n/tr-TR/billing.ts index 62d6e0a07e..6d01d9dd32 
100644 --- a/web/i18n/tr-TR/billing.ts +++ b/web/i18n/tr-TR/billing.ts @@ -94,6 +94,8 @@ const translation = { teamWorkspace: '{{count,number}} Takım Çalışma Alanı', documentsRequestQuotaTooltip: 'Bir çalışma alanının bilgi tabanında, veri seti oluşturma, silme, güncellemeler, belge yüklemeleri, değişiklikler, arşivleme ve bilgi tabanı sorguları dahil olmak üzere, dakikada gerçekleştirebileceği toplam işlem sayısını belirtir. Bu ölçüt, bilgi tabanı taleplerinin performansını değerlendirmek için kullanılır. Örneğin, bir Sandbox kullanıcısı bir dakika içinde ardışık 10 vurma testi gerçekleştirirse, çalışma alanı bir sonraki dakika için aşağıdaki işlemleri gerçekleştirmesi geçici olarak kısıtlanacaktır: veri seti oluşturma, silme, güncellemeler ve belge yüklemeleri veya değişiklikler.', startBuilding: 'İnşa Etmeye Başlayın', + taxTip: 'Tüm abonelik fiyatları (aylık/yıllık) geçerli vergiler (ör. KDV, satış vergisi) hariçtir.', + taxTipSecond: 'Bölgenizde geçerli vergi gereksinimleri yoksa, ödeme sayfanızda herhangi bir vergi görünmeyecek ve tüm abonelik süresi boyunca ek bir ücret tahsil edilmeyecektir.', }, plans: { sandbox: { diff --git a/web/i18n/uk-UA/billing.ts b/web/i18n/uk-UA/billing.ts index 10dafedb24..03b743e4fe 100644 --- a/web/i18n/uk-UA/billing.ts +++ b/web/i18n/uk-UA/billing.ts @@ -94,6 +94,8 @@ const translation = { apiRateLimitTooltip: 'Обмеження частоти запитів застосовується до всіх запитів, зроблених через API Dify, включаючи генерацію тексту, чат-розмови, виконання робочих процесів та обробку документів.', documentsRequestQuotaTooltip: 'Вказує загальну кількість дій, які робоча область може виконувати за хвилину в межах бази знань, включаючи створення, видалення, оновлення наборів даних, завантаження документів, модифікації, архівування та запити до бази знань. Цей показник використовується для оцінки ефективності запитів до бази знань. 
Наприклад, якщо користувач Sandbox виконує 10 послідовних тестів за один хвилину, його робочій області буде тимчасово заборонено виконувати наступні дії протягом наступної хвилини: створення наборів даних, видалення, оновлення, а також завантаження чи модифікацію документів.', startBuilding: 'Почніть будувати', + taxTip: 'Всі ціни на підписку (щомісячна/щорічна) не включають відповідні податки (наприклад, ПДВ, податок з продажу).', + taxTipSecond: 'Якщо для вашого регіону немає відповідних податкових вимог, податок не відображатиметься на вашому чек-ауті, і з вас не стягуватимуть додаткові збори протягом усього терміну підписки.', }, plans: { sandbox: { diff --git a/web/i18n/vi-VN/billing.ts b/web/i18n/vi-VN/billing.ts index 68e662425f..0166185e45 100644 --- a/web/i18n/vi-VN/billing.ts +++ b/web/i18n/vi-VN/billing.ts @@ -94,6 +94,8 @@ const translation = { freeTrialTipSuffix: 'Không cần thẻ tín dụng', documentsRequestQuotaTooltip: 'Chỉ định tổng số hành động mà một không gian làm việc có thể thực hiện mỗi phút trong cơ sở tri thức, bao gồm tạo mới tập dữ liệu, xóa, cập nhật, tải tài liệu lên, thay đổi, lưu trữ và truy vấn cơ sở tri thức. Chỉ số này được sử dụng để đánh giá hiệu suất của các yêu cầu cơ sở tri thức. 
Ví dụ, nếu một người dùng Sandbox thực hiện 10 lần kiểm tra liên tiếp trong một phút, không gian làm việc của họ sẽ bị hạn chế tạm thời không thực hiện các hành động sau trong phút tiếp theo: tạo mới tập dữ liệu, xóa, cập nhật và tải tài liệu lên hoặc thay đổi.', startBuilding: 'Bắt đầu xây dựng', + taxTipSecond: 'Nếu khu vực của bạn không có yêu cầu thuế áp dụng, sẽ không có thuế xuất hiện trong quá trình thanh toán của bạn và bạn sẽ không bị tính bất kỳ khoản phí bổ sung nào trong suốt thời gian đăng ký.', + taxTip: 'Tất cả giá đăng ký (hàng tháng/hàng năm) chưa bao gồm các loại thuế áp dụng (ví dụ: VAT, thuế bán hàng).', }, plans: { sandbox: { diff --git a/web/i18n/zh-Hans/billing.ts b/web/i18n/zh-Hans/billing.ts index 6d21693f88..f200b94e33 100644 --- a/web/i18n/zh-Hans/billing.ts +++ b/web/i18n/zh-Hans/billing.ts @@ -36,6 +36,8 @@ const translation = { save: '节省', free: '免费', annualBilling: '按年计费节省 {{percent}}%', + taxTip: '所有订阅价格(按月/按年)均不含适用税费(如增值税、销售税)。', + taxTipSecond: '如果您所在地区无适用税费要求,结账时将不会显示税费,且在整个订阅周期内您都无需支付任何额外费用。', comparePlanAndFeatures: '对比套餐 & 功能特性', priceTip: '每个团队空间/', currentPlan: '当前计划', diff --git a/web/i18n/zh-Hant/billing.ts b/web/i18n/zh-Hant/billing.ts index f99b1ef2cf..1b0b1f5e1f 100644 --- a/web/i18n/zh-Hant/billing.ts +++ b/web/i18n/zh-Hant/billing.ts @@ -94,6 +94,8 @@ const translation = { documentsTooltip: '從知識數據來源導入的文件數量配額。', documentsRequestQuotaTooltip: '指定工作區在知識基礎中每分鐘可以執行的總操作次數,包括數據集的創建、刪除、更新、文檔上傳、修改、歸檔和知識基礎查詢。這個指標用於評估知識基礎請求的性能。例如,如果一個沙箱用戶在一分鐘內連續執行 10 次命中測試,他們的工作區將在接下來的一分鐘內暫時禁止執行以下操作:數據集的創建、刪除、更新以及文檔上傳或修改。', startBuilding: '開始建造', + taxTip: '所有訂閱價格(月費/年費)不包含適用的稅費(例如增值稅、銷售稅)。', + taxTipSecond: '如果您的地區沒有適用的稅務要求,結帳時將不會顯示任何稅款,且在整個訂閱期間您也不會被收取任何額外費用。', }, plans: { sandbox: { diff --git a/web/package.json b/web/package.json index 62cccf0610..2a8972ee80 100644 --- a/web/package.json +++ b/web/package.json @@ -2,7 +2,7 @@ "name": "dify-web", "version": "1.9.1", "private": true, - "packageManager": "pnpm@10.17.1", + "packageManager": 
"pnpm@10.18.2", "engines": { "node": ">=v22.11.0" }, diff --git a/web/utils/mcp.ts b/web/utils/mcp.ts new file mode 100644 index 0000000000..dcbb63ee8a --- /dev/null +++ b/web/utils/mcp.ts @@ -0,0 +1,22 @@ +/** + * MCP (Model Context Protocol) utility functions + */ + +/** + * Determines if the MCP icon should be used based on the icon source + * @param src - The icon source, can be a string URL or an object with content and background + * @returns true if the MCP icon should be used (when it's an emoji object with 🔗 content) + */ +export const shouldUseMcpIcon = (src: any): boolean => { + return typeof src === 'object' && src?.content === '🔗' +} + +/** + * Checks if an app icon should use the MCP icon + * @param iconType - The type of icon ('emoji' | 'image') + * @param icon - The icon content (emoji or file ID) + * @returns true if the MCP icon should be used + */ +export const shouldUseMcpIconForAppIcon = (iconType: string, icon: string): boolean => { + return iconType === 'emoji' && icon === '🔗' +}