diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 91e2a90e5e..3a13bb6b67 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -739,6 +739,8 @@ class DatasetApiDeleteApi(Resource):
         db.session.commit()
         return {"result": "success"}, 204
 
+
+
 @console_ns.route("/datasets//api-keys/")
 class DatasetEnableApiApi(Resource):
     @setup_required
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 6b635bcfbd..a9f7608733 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -124,8 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
                 args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
             )
 
-        upload_file = FileService(db.engine).upload_text(text=str(text),
-            text_name=str(name), user_id=current_user.id, tenant_id=tenant_id)
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -203,8 +204,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         name = args.get("name")
         if text is None or name is None:
             raise ValueError("Both text and name must be strings.")
-        upload_file = FileService(db.engine).upload_text(text=str(text),
-            text_name=str(name), user_id=current_user.id, tenant_id=tenant_id)
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
index 75e2ba9f63..55bfdde009 100644
--- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
@@ -41,7 +41,7 @@ class DatasourcePluginsApi(DatasetApiResource):
     @service_api_ns.doc(
         params={
             "is_published": "Whether to get published or draft datasource plugins "
-                            "(true for published, false for draft, default: true)"
+            "(true for published, false for draft, default: true)"
         }
     )
     @service_api_ns.doc(
@@ -54,15 +54,14 @@
         """Resource for getting datasource plugins."""
         # Get query parameter to determine published or draft
         is_published: bool = request.args.get("is_published", default=True, type=bool)
-
+
         rag_pipeline_service: RagPipelineService = RagPipelineService()
         datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
-            tenant_id=tenant_id,
-            dataset_id=dataset_id,
-            is_published=is_published
+            tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
         )
         return datasource_plugins, 200
 
+
+
 @service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
 class DatasourceNodeRunApi(DatasetApiResource):
     """Resource for datasource node run."""
@@ -80,7 +79,7 @@
             "datasource_type": "Datasource type, e.g. online_document",
             "credential_id": "Credential ID",
             "is_published": "Whether to get published or draft datasource plugins "
-                            "(true for published, false for draft, default: true)"
+            "(true for published, false for draft, default: true)",
         }
     )
     @service_api_ns.doc(
@@ -136,8 +135,8 @@ class PipelineRunApi(DatasetApiResource):
             "datasource_info_list": "Datasource info list",
             "start_node_id": "Start node ID",
             "is_published": "Whether to get published or draft datasource plugins "
-                            "(true for published, false for draft, default: true)",
-            "streaming": "Whether to stream the response(streaming or blocking), default: streaming"
+            "(true for published, false for draft, default: true)",
+            "streaming": "Whether to stream the response(streaming or blocking), default: streaming",
         }
     )
     @service_api_ns.doc(
@@ -154,9 +153,16 @@
         parser.add_argument("datasource_info_list", type=list, required=True, location="json")
         parser.add_argument("start_node_id", type=str, required=True, location="json")
         parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
-        parser.add_argument("response_mode", type=str, required=True, choices=["streaming", "blocking"], default="blocking", location="json")
+        parser.add_argument(
+            "response_mode",
+            type=str,
+            required=True,
+            choices=["streaming", "blocking"],
+            default="blocking",
+            location="json",
+        )
         args: ParseResult = parser.parse_args()
-
+
         if not isinstance(current_user, Account):
             raise Forbidden()
 
@@ -173,7 +179,7 @@
 
             return helper.compact_generate_response(response)
         except Exception as ex:
-            raise PipelineRunError(description=str(ex)) 
+            raise PipelineRunError(description=str(ex))
 
 
 @service_api_ns.route("/datasets/pipeline/file-upload")
@@ -189,7 +195,6 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
             413: "File too large",
             415: "Unsupported file type",
-
         }
     )
     def post(self, tenant_id: str):
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index 246d3750d1..ee8e1d105b 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -204,7 +204,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
             if not dataset_id and args:
                 # For class methods: args[0] is self, args[1] is dataset_id (if exists)
                 # Check if first arg is likely a class instance (has __dict__ or __class__)
-                if len(args) > 1 and hasattr(args[0], '__dict__'):
+                if len(args) > 1 and hasattr(args[0], "__dict__"):
                     # This is a class method, dataset_id should be in args[1]
                     potential_id = args[1]
                     # Validate it's a string-like UUID, not another object
@@ -212,7 +212,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
                         # Try to convert to string and check if it's a valid UUID format
                         str_id = str(potential_id)
                         # Basic check: UUIDs are 36 chars with hyphens
-                        if len(str_id) == 36 and str_id.count('-') == 4:
+                        if len(str_id) == 36 and str_id.count("-") == 4:
                             dataset_id = str_id
                     except:
                         pass
@@ -221,7 +221,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
                     potential_id = args[0]
                     try:
                         str_id = str(potential_id)
-                        if len(str_id) == 36 and str_id.count('-') == 4:
+                        if len(str_id) == 36 and str_id.count("-") == 4:
                             dataset_id = str_id
                     except:
                         pass
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index c9daead0ba..574d9c71bf 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -137,6 +137,7 @@
         documents: list[Document] = []
         if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
             from services.dataset_service import DocumentService
+
             for datasource_info in datasource_info_list:
                 position = DocumentService.get_documents_position(dataset.id)
                 document = self._build_document(
@@ -234,16 +235,18 @@
                     workflow_thread_pool_id=workflow_thread_pool_id,
                 )
             else:
-                rag_pipeline_invoke_entities.append(RagPipelineInvokeEntity(
-                    pipeline_id=pipeline.id,
-                    user_id=user.id,
-                    tenant_id=pipeline.tenant_id,
-                    workflow_id=workflow.id,
-                    streaming=streaming,
-                    workflow_execution_id=workflow_run_id,
-                    workflow_thread_pool_id=workflow_thread_pool_id,
-                    application_generate_entity=application_generate_entity.model_dump(),
-                ))
+                rag_pipeline_invoke_entities.append(
+                    RagPipelineInvokeEntity(
+                        pipeline_id=pipeline.id,
+                        user_id=user.id,
+                        tenant_id=pipeline.tenant_id,
+                        workflow_id=workflow.id,
+                        streaming=streaming,
+                        workflow_execution_id=workflow_run_id,
+                        workflow_thread_pool_id=workflow_thread_pool_id,
+                        application_generate_entity=application_generate_entity.model_dump(),
+                    )
+                )
 
         if rag_pipeline_invoke_entities:
             # store the rag_pipeline_invoke_entities to object storage
diff --git a/api/core/app/entities/rag_pipeline_invoke_entities.py b/api/core/app/entities/rag_pipeline_invoke_entities.py
index b26f496c8a..992b8da893 100644
--- a/api/core/app/entities/rag_pipeline_invoke_entities.py
+++ b/api/core/app/entities/rag_pipeline_invoke_entities.py
@@ -11,4 +11,4 @@ class RagPipelineInvokeEntity(BaseModel):
     workflow_id: str
     streaming: bool
     workflow_execution_id: str | None = None
-    workflow_thread_pool_id: str | None = None
\ No newline at end of file
+    workflow_thread_pool_id: str | None = None
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index 3d69d86b65..97052717db 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -29,9 +29,7 @@
         with redis_client.lock(lock_name, timeout=600):
             keyword_table_handler = JiebaKeywordTableHandler()
             keyword_table = self._get_dataset_keyword_table()
-            keyword_number = (
-                self.dataset.keyword_number or self._config.max_keywords_per_chunk
-            )
+            keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
 
             for text in texts:
                 keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
@@ -52,9 +50,7 @@
             keyword_table = self._get_dataset_keyword_table()
 
             keywords_list = kwargs.get("keywords_list")
-            keyword_number = (
-                self.dataset.keyword_number or self._config.max_keywords_per_chunk
-            )
+            keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
             for i in range(len(texts)):
                 text = texts[i]
                 if keywords_list:
@@ -239,9 +235,7 @@
                     keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                 )
             else:
-                keyword_number = (
-                    self.dataset.keyword_number or self._config.max_keywords_per_chunk
-                )
+                keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
                 keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
                 segment.keywords = list(keywords)
 
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index d0d3c2d426..ed2301e172 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -10,7 +10,6 @@ from collections.abc import Sequence
 from typing import Any, Literal
 
 import sqlalchemy as sa
-import yaml
 from sqlalchemy import exists, func, select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
@@ -60,7 +59,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeConfiguration,
     RagPipelineDatasetCreateEntity,
-    RetrievalSetting,
 )
 from services.errors.account import NoPermissionError
 from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@@ -1020,7 +1018,6 @@
         dataset.updated_at = naive_utc_now()
         db.session.commit()
 
-
     @staticmethod
     def get_dataset_auto_disable_logs(dataset_id: str):
         assert isinstance(current_user, Account)
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 1b5077df7b..870360ceb6 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -345,7 +345,7 @@
     def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
         """
         check if tenant oauth params is enabled
-        """ 
+        """
         return (
             db.session.query(DatasourceOauthTenantParamConfig)
             .filter_by(
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 5708efba3c..f0bb68766d 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -19,7 +19,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_storage import storage
 from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_tenant_id
-from libs.login import current_user
 from models.account import Account
 from models.enums import CreatorUserRole
 from models.model import EndUser, UploadFile
@@ -121,7 +120,6 @@
         return file_size <= file_size_limit
 
     def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
-
         if len(text_name) > 200:
             text_name = text_name[:200]
         # user uuid as file name
@@ -241,4 +239,4 @@
             return
         storage.delete(upload_file.key)
         session.delete(upload_file)
-        session.commit()
\ No newline at end of file
+        session.commit()
diff --git a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py
index f1718e3cc8..35005fad71 100644
--- a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py
+++ b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py
@@ -1,4 +1,6 @@
-from typing import Any, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any, Optional
+
 from pydantic import BaseModel
 
 
@@ -10,10 +12,11 @@ class DatasourceNodeRunApiEntity(BaseModel):
     credential_id: Optional[str] = None
     is_published: bool
 
+
 class PipelineRunApiEntity(BaseModel):
     inputs: Mapping[str, Any]
     datasource_type: str
     datasource_info_list: list[Mapping[str, Any]]
     start_node_id: str
     is_published: bool
-    response_mode: str
\ No newline at end of file
+    response_mode: str
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index f9ef050c52..4f97e0f9bc 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -7,7 +7,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
 from uuid import uuid4
-import uuid
 
 from flask_login import current_user
 from sqlalchemy import func, or_, select
@@ -15,7 +14,6 @@ from sqlalchemy.orm import Session, sessionmaker
 
 import contexts
 from configs import dify_config
-from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
 from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import (
@@ -57,7 +55,14 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin  # type: ignore
+from models.dataset import (  # type: ignore
+    Dataset,
+    Document,
+    DocumentPipelineExecutionLog,
+    Pipeline,
+    PipelineCustomizedTemplate,
+    PipelineRecommendedPlugin,
+)
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
@@ -1320,8 +1325,11 @@
         """
         Retry error document
        """
-        document_pipeline_excution_log = db.session.query(DocumentPipelineExecutionLog).filter(
-            DocumentPipelineExecutionLog.document_id == document.id).first()
+        document_pipeline_excution_log = (
+            db.session.query(DocumentPipelineExecutionLog)
+            .filter(DocumentPipelineExecutionLog.document_id == document.id)
+            .first()
+        )
         if not document_pipeline_excution_log:
             raise ValueError("Document pipeline execution log not found")
         pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
index 5ccc51a66a..7021ddab38 100644
--- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
@@ -52,19 +52,21 @@ def priority_rag_pipeline_run_task(
 
     try:
         start_at = time.perf_counter()
-        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id)
+        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
+            rag_pipeline_invoke_entities_file_id
+        )
         rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
-
+
         # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore
-
+
         with ThreadPoolExecutor(max_workers=10) as executor:
             futures = []
             for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                 # Submit task to thread pool with Flask app
                 future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                 futures.append(future)
-
+
             # Wait for all tasks to complete
             for future in futures:
                 try:
@@ -73,7 +75,9 @@
                     logging.exception("Error in pipeline task")
         end_at = time.perf_counter()
         logging.info(
-            click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green")
+            click.style(
+                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
+            )
         )
     except Exception:
         logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
@@ -83,6 +87,7 @@
         file_service.delete_file(rag_pipeline_invoke_entities_file_id)
         db.session.close()
 
+
 def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
     """Run a single RAG pipeline task within Flask app context."""
     # Create Flask application context for this thread
@@ -97,13 +102,13 @@
             workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
             workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
             application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
-
+
             with Session(db.engine) as session:
                 # Load required entities
                 account = session.query(Account).filter(Account.id == user_id).first()
                 if not account:
                     raise ValueError(f"Account {user_id} not found")
-
+
                 tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
                 if not tenant:
                     raise ValueError(f"Tenant {tenant_id} not found")
diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
index 5d64177e7e..d71a305b14 100644
--- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
@@ -54,7 +54,8 @@ def rag_pipeline_run_task(
     try:
         start_at = time.perf_counter()
         rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
-            rag_pipeline_invoke_entities_file_id)
+            rag_pipeline_invoke_entities_file_id
+        )
         rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
 
         # Get Flask app object for thread context
@@ -75,8 +76,9 @@
                     logging.exception("Error in pipeline task")
         end_at = time.perf_counter()
         logging.info(
-            click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s",
-                        fg="green")
+            click.style(
+                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
+            )
         )
     except Exception:
         logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
@@ -94,8 +96,9 @@
             # Keep the flag set to indicate a task is running
             redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
             rag_pipeline_run_task.delay(  # type: ignore
-                rag_pipeline_invoke_entities_file_id=next_file_id.decode('utf-8') if isinstance(next_file_id,
-                                                                                                bytes) else next_file_id,
+                rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
+                if isinstance(next_file_id, bytes)
+                else next_file_id,
                 tenant_id=tenant_id,
             )
         else:
@@ -120,13 +123,13 @@
             workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
             workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
             application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
-
+
             with Session(db.engine) as session:
                 # Load required entities
                 account = session.query(Account).filter(Account.id == user_id).first()
                 if not account:
                     raise ValueError(f"Account {user_id} not found")
-
+
                 tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
                 if not tenant:
                     raise ValueError(f"Tenant {tenant_id} not found")