dev/reformat

jyong 2025-09-16 16:08:04 +08:00
parent 05aec66424
commit 0ec037b803
14 changed files with 87 additions and 67 deletions

View File

@@ -739,6 +739,8 @@ class DatasetApiDeleteApi(Resource):
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
class DatasetEnableApiApi(Resource):
@setup_required

View File

@@ -124,8 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)
upload_file = FileService(db.engine).upload_text(text=str(text),
text_name=str(name), user_id=current_user.id, tenant_id=tenant_id)
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -203,8 +204,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
upload_file = FileService(db.engine).upload_text(text=str(text),
text_name=str(name), user_id=current_user.id, tenant_id=tenant_id)
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},

View File

@@ -41,7 +41,7 @@ class DatasourcePluginsApi(DatasetApiResource):
@service_api_ns.doc(
params={
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)"
"(true for published, false for draft, default: true)"
}
)
@service_api_ns.doc(
@@ -54,15 +54,14 @@ class DatasourcePluginsApi(DatasetApiResource):
"""Resource for getting datasource plugins."""
# Get query parameter to determine published or draft
is_published: bool = request.args.get("is_published", default=True, type=bool)
rag_pipeline_service: RagPipelineService = RagPipelineService()
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
tenant_id=tenant_id,
dataset_id=dataset_id,
is_published=is_published
tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
)
return datasource_plugins, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
class DatasourceNodeRunApi(DatasetApiResource):
"""Resource for datasource node run."""
@@ -80,7 +79,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
"datasource_type": "Datasource type, e.g. online_document",
"credential_id": "Credential ID",
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)"
"(true for published, false for draft, default: true)",
}
)
@service_api_ns.doc(
@@ -136,8 +135,8 @@ class PipelineRunApi(DatasetApiResource):
"datasource_info_list": "Datasource info list",
"start_node_id": "Start node ID",
"is_published": "Whether to get published or draft datasource plugins "
"(true for published, false for draft, default: true)",
"streaming": "Whether to stream the response(streaming or blocking), default: streaming"
"(true for published, false for draft, default: true)",
"streaming": "Whether to stream the response(streaming or blocking), default: streaming",
}
)
@service_api_ns.doc(
@@ -154,9 +153,16 @@ class PipelineRunApi(DatasetApiResource):
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
parser.add_argument("response_mode", type=str, required=True, choices=["streaming", "blocking"], default="blocking", location="json")
parser.add_argument(
"response_mode",
type=str,
required=True,
choices=["streaming", "blocking"],
default="blocking",
location="json",
)
args: ParseResult = parser.parse_args()
if not isinstance(current_user, Account):
raise Forbidden()
@@ -173,7 +179,7 @@ class PipelineRunApi(DatasetApiResource):
return helper.compact_generate_response(response)
except Exception as ex:
raise PipelineRunError(description=str(ex))
raise PipelineRunError(description=str(ex))
@service_api_ns.route("/datasets/pipeline/file-upload")
@@ -189,7 +195,6 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
413: "File too large",
415: "Unsupported file type",
}
)
def post(self, tenant_id: str):
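Reviewer note on the DatasourcePluginsApi.get hunk above: request.args.get("is_published", default=True, type=bool) coerces any non-empty string, including "false", to True, because type=bool simply applies bool() to the raw text. If the query parameter should honour false, an explicit parser along these lines could be used (a minimal sketch, not part of this commit; the helper name parse_bool_arg is illustrative):

from flask import request


def parse_bool_arg(name: str, default: bool = True) -> bool:
    # Query-string values arrive as text; treat only these tokens as true.
    raw = request.args.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


# e.g. inside DatasourcePluginsApi.get:
# is_published = parse_bool_arg("is_published", default=True)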

View File

@@ -204,7 +204,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
if not dataset_id and args:
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
# Check if first arg is likely a class instance (has __dict__ or __class__)
if len(args) > 1 and hasattr(args[0], '__dict__'):
if len(args) > 1 and hasattr(args[0], "__dict__"):
# This is a class method, dataset_id should be in args[1]
potential_id = args[1]
# Validate it's a string-like UUID, not another object
@@ -212,7 +212,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
# Try to convert to string and check if it's a valid UUID format
str_id = str(potential_id)
# Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count('-') == 4:
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except:
pass
@@ -221,7 +221,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count('-') == 4:
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except:
pass
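Reviewer note: the 36-characters-and-4-hyphens heuristic above accepts some strings that are not valid UUIDs, and the bare except: clauses swallow every error. If stricter validation is ever wanted, a standard-library check could look like this (sketch only, not taken from the codebase):

import uuid


def looks_like_uuid(value: object) -> bool:
    # Returns True only for canonical hyphenated UUID strings.
    try:
        return str(uuid.UUID(str(value))) == str(value).lower()
    except (TypeError, ValueError):
        return False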

View File

@@ -137,6 +137,7 @@ class PipelineGenerator(BaseAppGenerator):
documents: list[Document] = []
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
@@ -234,16 +235,18 @@ class PipelineGenerator(BaseAppGenerator):
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
rag_pipeline_invoke_entities.append(RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
workflow_id=workflow.id,
streaming=streaming,
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
))
rag_pipeline_invoke_entities.append(
RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
workflow_id=workflow.id,
streaming=streaming,
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
)
)
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities to object storage
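The hunk ends just before the object-storage step. Given that the Celery tasks further down read the payload back with FileService(db.engine).get_file_content(...), the persisting side presumably looks roughly like the sketch below (an assumption, not code from this commit; the file name is illustrative, the task import is assumed, and model_dump() is assumed to be JSON-serializable):

import json

# Serialize the pydantic entities so the queued task can re-load them by file id.
payload = json.dumps([entity.model_dump() for entity in rag_pipeline_invoke_entities])
upload_file = FileService(db.engine).upload_text(
    text=payload,
    text_name=f"rag_pipeline_invoke_entities_{workflow_run_id}.json",
    user_id=user.id,
    tenant_id=pipeline.tenant_id,
)
rag_pipeline_run_task.delay(
    rag_pipeline_invoke_entities_file_id=upload_file.id,
    tenant_id=pipeline.tenant_id,
)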

View File

@@ -11,4 +11,4 @@ class RagPipelineInvokeEntity(BaseModel):
workflow_id: str
streaming: bool
workflow_execution_id: str | None = None
workflow_thread_pool_id: str | None = None
workflow_thread_pool_id: str | None = None

View File

@@ -29,9 +29,7 @@ class Jieba(BaseKeyword):
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keyword_number = (
self.dataset.keyword_number or self._config.max_keywords_per_chunk
)
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
@@ -52,9 +50,7 @@ class Jieba(BaseKeyword):
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keyword_number = (
self.dataset.keyword_number or self._config.max_keywords_per_chunk
)
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
for i in range(len(texts)):
text = texts[i]
if keywords_list:
@@ -239,9 +235,7 @@ class Jieba(BaseKeyword):
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keyword_number = (
self.dataset.keyword_number or self._config.max_keywords_per_chunk
)
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)
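One behavioural note on the simplified fallback: `or` replaces not only None but also a stored keyword_number of 0 with the config default. If 0 is ever meant as a valid "no keywords" setting, the explicit form below would preserve it; otherwise the one-liner in this diff is equivalent (sketch for comparison only):

keyword_number = (
    self.dataset.keyword_number
    if self.dataset.keyword_number is not None
    else self._config.max_keywords_per_chunk
)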

View File

@@ -10,7 +10,6 @@ from collections.abc import Sequence
from typing import Any, Literal
import sqlalchemy as sa
import yaml
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -60,7 +59,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
RetrievalSetting,
)
from services.errors.account import NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@@ -1020,7 +1018,6 @@ class DatasetService:
dataset.updated_at = naive_utc_now()
db.session.commit()
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str):
assert isinstance(current_user, Account)

View File

@@ -345,7 +345,7 @@ class DatasourceProviderService:
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check if tenant oauth params is enabled
"""
"""
return (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(

View File

@@ -19,7 +19,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from libs.login import current_user
from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
@@ -121,7 +120,6 @@ class FileService:
return file_size <= file_size_limit
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]
# user uuid as file name
@@ -241,4 +239,4 @@ class FileService:
return
storage.delete(upload_file.key)
session.delete(upload_file)
session.commit()
session.commit()

View File

@@ -1,4 +1,6 @@
from typing import Any, Mapping, Optional
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import BaseModel
@@ -10,10 +12,11 @@ class DatasourceNodeRunApiEntity(BaseModel):
credential_id: Optional[str] = None
is_published: bool
class PipelineRunApiEntity(BaseModel):
inputs: Mapping[str, Any]
datasource_type: str
datasource_info_list: list[Mapping[str, Any]]
start_node_id: str
is_published: bool
response_mode: str
response_mode: str
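With Mapping now imported from collections.abc (the typing alias is deprecated), the model still validates plain dicts as before. A construction sketch with illustrative values, to show how the parsed request args map onto the fields:

entity = PipelineRunApiEntity(
    inputs={"name": "my-dataset"},
    datasource_type="online_document",
    datasource_info_list=[{"workspace_id": "ws-1", "page_id": "pg-1"}],
    start_node_id="start-node",
    is_published=True,
    response_mode="streaming",
)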

View File

@@ -7,7 +7,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from uuid import uuid4
import uuid
from flask_login import current_user
from sqlalchemy import func, or_, select
@@ -15,7 +14,6 @@ from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
@@ -57,7 +55,14 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin # type: ignore
from models.dataset import ( # type: ignore
Dataset,
Document,
DocumentPipelineExecutionLog,
Pipeline,
PipelineCustomizedTemplate,
PipelineRecommendedPlugin,
)
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
@@ -1320,8 +1325,11 @@ class RagPipelineService:
"""
Retry error document
"""
document_pipeline_excution_log = db.session.query(DocumentPipelineExecutionLog).filter(
DocumentPipelineExecutionLog.document_id == document.id).first()
document_pipeline_excution_log = (
db.session.query(DocumentPipelineExecutionLog)
.filter(DocumentPipelineExecutionLog.document_id == document.id)
.first()
)
if not document_pipeline_excution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()

View File

@@ -52,19 +52,21 @@ def priority_rag_pipeline_run_task(
try:
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id)
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
rag_pipeline_invoke_entities_file_id
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
# Get Flask app object for thread context
flask_app = current_app._get_current_object() # type: ignore
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool with Flask app
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
futures.append(future)
# Wait for all tasks to complete
for future in futures:
try:
@@ -73,7 +75,9 @@ def priority_rag_pipeline_run_task(
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green")
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
@@ -83,6 +87,7 @@ def priority_rag_pipeline_run_task(
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
"""Run a single RAG pipeline task within Flask app context."""
# Create Flask application context for this thread
@@ -97,13 +102,13 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
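The fan-out in this task depends on each worker thread re-entering the Flask application context before touching the database; reduced to its core, the pattern is roughly the following (a sketch independent of this diff's entity model, with illustrative helper names):

from concurrent.futures import ThreadPoolExecutor

from flask import Flask, current_app


def run_in_app_context(app: Flask, fn, *args, **kwargs):
    # Each worker thread needs its own app context for db/session access.
    with app.app_context():
        return fn(*args, **kwargs)


def fan_out(tasks, max_workers: int = 10):
    app = current_app._get_current_object()  # unwrap the proxy before crossing threads
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(run_in_app_context, app, task) for task in tasks]
        return [future.result() for future in futures]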

View File

@@ -54,7 +54,8 @@ def rag_pipeline_run_task(
try:
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
rag_pipeline_invoke_entities_file_id)
rag_pipeline_invoke_entities_file_id
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
# Get Flask app object for thread context
@@ -75,8 +76,9 @@ def rag_pipeline_run_task(
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s",
fg="green")
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
@@ -94,8 +96,9 @@ def rag_pipeline_run_task(
# Keep the flag set to indicate a task is running
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode('utf-8') if isinstance(next_file_id,
bytes) else next_file_id,
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
@@ -120,13 +123,13 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
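On the chaining side of rag_pipeline_run_task, the next queued file id read from Redis comes back as bytes, hence the decode("utf-8") guard before re-queueing. A small helper would keep that normalisation in one place (sketch; the helper is not part of this commit):

def as_str(value: bytes | str | None) -> str | None:
    # Redis returns bytes by default; normalise to str for task arguments.
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value


# e.g. when dispatching the next queued run:
# rag_pipeline_run_task.delay(
#     rag_pipeline_invoke_entities_file_id=as_str(next_file_id),
#     tenant_id=tenant_id,
# )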