mirror of https://github.com/langgenius/dify.git (synced 2026-03-10 03:00:20 +08:00)

dev/reformat

commit 0ec037b803
parent 05aec66424
@@ -739,6 +739,8 @@ class DatasetApiDeleteApi(Resource):
         db.session.commit()

         return {"result": "success"}, 204


 @console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<str:status>")
 class DatasetEnableApiApi(Resource):
     @setup_required
@@ -124,8 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
                 args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
             )

-        upload_file = FileService(db.engine).upload_text(text=str(text),
-                                                         text_name=str(name), user_id=current_user.id, tenant_id=tenant_id)
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -203,8 +204,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         name = args.get("name")
         if text is None or name is None:
             raise ValueError("Both text and name must be strings.")
-        upload_file = FileService(db.engine).upload_text(text=str(text),
-                                                         text_name=str(name), user_id=current_user.id, tenant_id=tenant_id)
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -41,7 +41,7 @@ class DatasourcePluginsApi(DatasetApiResource):
     @service_api_ns.doc(
         params={
             "is_published": "Whether to get published or draft datasource plugins "
-                            "(true for published, false for draft, default: true)"
+            "(true for published, false for draft, default: true)"
         }
     )
     @service_api_ns.doc(
@@ -54,15 +54,14 @@ class DatasourcePluginsApi(DatasetApiResource):
         """Resource for getting datasource plugins."""
         # Get query parameter to determine published or draft
         is_published: bool = request.args.get("is_published", default=True, type=bool)

         rag_pipeline_service: RagPipelineService = RagPipelineService()
         datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
-            tenant_id=tenant_id,
-            dataset_id=dataset_id,
-            is_published=is_published
+            tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
         )
         return datasource_plugins, 200


 @service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
 class DatasourceNodeRunApi(DatasetApiResource):
     """Resource for datasource node run."""
@@ -80,7 +79,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
             "datasource_type": "Datasource type, e.g. online_document",
             "credential_id": "Credential ID",
             "is_published": "Whether to get published or draft datasource plugins "
-            "(true for published, false for draft, default: true)"
+            "(true for published, false for draft, default: true)",
         }
     )
     @service_api_ns.doc(
@@ -136,8 +135,8 @@ class PipelineRunApi(DatasetApiResource):
             "datasource_info_list": "Datasource info list",
             "start_node_id": "Start node ID",
             "is_published": "Whether to get published or draft datasource plugins "
-            "(true for published, false for draft, default: true)",
-            "streaming": "Whether to stream the response(streaming or blocking), default: streaming"
+            "(true for published, false for draft, default: true)",
+            "streaming": "Whether to stream the response(streaming or blocking), default: streaming",
         }
     )
     @service_api_ns.doc(
@@ -154,9 +153,16 @@ class PipelineRunApi(DatasetApiResource):
         parser.add_argument("datasource_info_list", type=list, required=True, location="json")
         parser.add_argument("start_node_id", type=str, required=True, location="json")
         parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
-        parser.add_argument("response_mode", type=str, required=True, choices=["streaming", "blocking"], default="blocking", location="json")
+        parser.add_argument(
+            "response_mode",
+            type=str,
+            required=True,
+            choices=["streaming", "blocking"],
+            default="blocking",
+            location="json",
+        )
         args: ParseResult = parser.parse_args()

         if not isinstance(current_user, Account):
             raise Forbidden()
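
Side note on the argument above: with reqparse (the flask_restful / flask_restx parser this code appears to use), required=True means a missing response_mode aborts the request with HTTP 400, so default="blocking" can never actually fire. A minimal, self-contained sketch of that behavior; the app and payload here are illustrative, not Dify's:

    from flask import Flask
    from flask_restful import reqparse

    app = Flask(__name__)

    parser = reqparse.RequestParser()
    parser.add_argument(
        "response_mode",
        type=str,
        required=True,  # a missing key aborts with HTTP 400 ...
        choices=["streaming", "blocking"],
        default="blocking",  # ... so this default never applies in practice
        location="json",
    )

    with app.test_request_context(json={"response_mode": "streaming"}):
        args = parser.parse_args()
        assert args["response_mode"] == "streaming"
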
@@ -173,7 +179,7 @@ class PipelineRunApi(DatasetApiResource):

             return helper.compact_generate_response(response)
         except Exception as ex:
-            raise PipelineRunError(description=str(ex))
+            raise PipelineRunError(description=str(ex))


 @service_api_ns.route("/datasets/pipeline/file-upload")
@@ -189,7 +195,6 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
             413: "File too large",
             415: "Unsupported file type",
-
         }
     )
     def post(self, tenant_id: str):
@@ -204,7 +204,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
         if not dataset_id and args:
             # For class methods: args[0] is self, args[1] is dataset_id (if exists)
             # Check if first arg is likely a class instance (has __dict__ or __class__)
-            if len(args) > 1 and hasattr(args[0], '__dict__'):
+            if len(args) > 1 and hasattr(args[0], "__dict__"):
                 # This is a class method, dataset_id should be in args[1]
                 potential_id = args[1]
                 # Validate it's a string-like UUID, not another object
@@ -212,7 +212,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
                     # Try to convert to string and check if it's a valid UUID format
                     str_id = str(potential_id)
                     # Basic check: UUIDs are 36 chars with hyphens
-                    if len(str_id) == 36 and str_id.count('-') == 4:
+                    if len(str_id) == 36 and str_id.count("-") == 4:
                         dataset_id = str_id
                 except:
                     pass
@@ -221,7 +221,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
                 potential_id = args[0]
                 try:
                     str_id = str(potential_id)
-                    if len(str_id) == 36 and str_id.count('-') == 4:
+                    if len(str_id) == 36 and str_id.count("-") == 4:
                         dataset_id = str_id
                 except:
                     pass
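
The length-and-hyphen test in both branches above is only a heuristic (any 36-character string containing four hyphens passes). A stricter check is possible with just the standard library; this is a sketch, not what Dify ships:

    import uuid

    def looks_like_uuid(value: str) -> bool:
        # uuid.UUID() raises ValueError for anything that is not a UUID;
        # comparing against the canonical form also rejects odd spellings
        # such as bare 32-char hex or braced UUIDs.
        try:
            return str(uuid.UUID(value)) == value.lower()
        except (ValueError, TypeError):
            return False

    assert looks_like_uuid("12345678-1234-5678-1234-567812345678")
    assert not looks_like_uuid("x" * 36)
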
@@ -137,6 +137,7 @@ class PipelineGenerator(BaseAppGenerator):
         documents: list[Document] = []
         if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
             from services.dataset_service import DocumentService
+
             for datasource_info in datasource_info_list:
                 position = DocumentService.get_documents_position(dataset.id)
                 document = self._build_document(
@@ -234,16 +235,18 @@ class PipelineGenerator(BaseAppGenerator):
                     workflow_thread_pool_id=workflow_thread_pool_id,
                 )
             else:
-                rag_pipeline_invoke_entities.append(RagPipelineInvokeEntity(
-                    pipeline_id=pipeline.id,
-                    user_id=user.id,
-                    tenant_id=pipeline.tenant_id,
-                    workflow_id=workflow.id,
-                    streaming=streaming,
-                    workflow_execution_id=workflow_run_id,
-                    workflow_thread_pool_id=workflow_thread_pool_id,
-                    application_generate_entity=application_generate_entity.model_dump(),
-                ))
+                rag_pipeline_invoke_entities.append(
+                    RagPipelineInvokeEntity(
+                        pipeline_id=pipeline.id,
+                        user_id=user.id,
+                        tenant_id=pipeline.tenant_id,
+                        workflow_id=workflow.id,
+                        streaming=streaming,
+                        workflow_execution_id=workflow_run_id,
+                        workflow_thread_pool_id=workflow_thread_pool_id,
+                        application_generate_entity=application_generate_entity.model_dump(),
+                    )
+                )

         if rag_pipeline_invoke_entities:
             # store the rag_pipeline_invoke_entities to object storage
@@ -11,4 +11,4 @@ class RagPipelineInvokeEntity(BaseModel):
     workflow_id: str
     streaming: bool
     workflow_execution_id: str | None = None
-    workflow_thread_pool_id: str | None = None
\ No newline at end of file
+    workflow_thread_pool_id: str | None = None
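
For context on how this entity is used: the generator change earlier in this diff appends RagPipelineInvokeEntity instances (with the generate entity serialized via model_dump()), and the run tasks later rebuild them from JSON pulled out of object storage. A minimal round-trip sketch; the condensed model and field values below are illustrative stand-ins, not Dify's exact definitions:

    import json
    from typing import Any, Optional

    from pydantic import BaseModel

    class RagPipelineInvokeEntity(BaseModel):  # condensed stand-in
        pipeline_id: str
        user_id: str
        tenant_id: str
        workflow_id: str
        streaming: bool
        workflow_execution_id: Optional[str] = None
        workflow_thread_pool_id: Optional[str] = None
        application_generate_entity: dict[str, Any] = {}

    entity = RagPipelineInvokeEntity(
        pipeline_id="p-1", user_id="u-1", tenant_id="t-1", workflow_id="w-1", streaming=True
    )
    payload = json.dumps([entity.model_dump()])  # what gets written to storage
    restored = [RagPipelineInvokeEntity.model_validate(e) for e in json.loads(payload)]
    assert restored[0] == entity
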
@@ -29,9 +29,7 @@ class Jieba(BaseKeyword):
         with redis_client.lock(lock_name, timeout=600):
             keyword_table_handler = JiebaKeywordTableHandler()
             keyword_table = self._get_dataset_keyword_table()
-            keyword_number = (
-                self.dataset.keyword_number or self._config.max_keywords_per_chunk
-            )
+            keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk

             for text in texts:
                 keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
@@ -52,9 +50,7 @@ class Jieba(BaseKeyword):

         keyword_table = self._get_dataset_keyword_table()
         keywords_list = kwargs.get("keywords_list")
-        keyword_number = (
-            self.dataset.keyword_number or self._config.max_keywords_per_chunk
-        )
+        keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
         for i in range(len(texts)):
             text = texts[i]
             if keywords_list:
@@ -239,9 +235,7 @@ class Jieba(BaseKeyword):
                     keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                 )
             else:
-                keyword_number = (
-                    self.dataset.keyword_number or self._config.max_keywords_per_chunk
-                )
+                keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk

                 keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
                 segment.keywords = list(keywords)
@@ -10,7 +10,6 @@ from collections.abc import Sequence
 from typing import Any, Literal

-import sqlalchemy as sa
 import yaml
 from sqlalchemy import exists, func, select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
@@ -60,7 +59,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeConfiguration,
-    RagPipelineDatasetCreateEntity,
     RetrievalSetting,
 )
 from services.errors.account import NoPermissionError
 from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@@ -1020,7 +1018,6 @@ class DatasetService:
         dataset.updated_at = naive_utc_now()
         db.session.commit()

-
     @staticmethod
     def get_dataset_auto_disable_logs(dataset_id: str):
         assert isinstance(current_user, Account)
@@ -345,7 +345,7 @@ class DatasourceProviderService:
     def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
         """
         check if tenant oauth params is enabled
-        """
+        """
         return (
             db.session.query(DatasourceOauthTenantParamConfig)
             .filter_by(
@@ -19,7 +19,6 @@ from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_storage import storage
 from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_tenant_id
-from libs.login import current_user
 from models.account import Account
 from models.enums import CreatorUserRole
 from models.model import EndUser, UploadFile
@@ -121,7 +120,6 @@ class FileService:
         return file_size <= file_size_limit

     def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
-
         if len(text_name) > 200:
             text_name = text_name[:200]
         # user uuid as file name
@@ -241,4 +239,4 @@ class FileService:
             return
         storage.delete(upload_file.key)
         session.delete(upload_file)
-        session.commit()
\ No newline at end of file
+        session.commit()
@@ -1,4 +1,6 @@
-from typing import Any, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any, Optional

 from pydantic import BaseModel

@@ -10,10 +12,11 @@ class DatasourceNodeRunApiEntity(BaseModel):
     credential_id: Optional[str] = None
     is_published: bool


 class PipelineRunApiEntity(BaseModel):
     inputs: Mapping[str, Any]
     datasource_type: str
     datasource_info_list: list[Mapping[str, Any]]
     start_node_id: str
     is_published: bool
-    response_mode: str
\ No newline at end of file
+    response_mode: str
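
The import change at the top of this file tracks a standard-library deprecation: typing.Mapping has been a deprecated alias of collections.abc.Mapping since Python 3.9, while Any and Optional still live in typing. A small illustration of the split, unrelated to Dify's code:

    from collections.abc import Mapping
    from typing import Any, Optional

    def first_key(payload: Mapping[str, Any]) -> Optional[str]:
        # Accepts any read-only mapping: dict, MappingProxyType, etc.
        return next(iter(payload), None)

    assert first_key({"a": 1}) == "a"
    assert first_key({}) is None
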
@@ -7,7 +7,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
 from uuid import uuid4
-import uuid

 from flask_login import current_user
 from sqlalchemy import func, or_, select
@@ -15,7 +14,6 @@ from sqlalchemy.orm import Session, sessionmaker

 import contexts
 from configs import dify_config
-from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
 from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import (
@@ -57,7 +55,14 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin  # type: ignore
+from models.dataset import (  # type: ignore
+    Dataset,
+    Document,
+    DocumentPipelineExecutionLog,
+    Pipeline,
+    PipelineCustomizedTemplate,
+    PipelineRecommendedPlugin,
+)
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
@@ -1320,8 +1325,11 @@ class RagPipelineService:
         """
         Retry error document
         """
-        document_pipeline_excution_log = db.session.query(DocumentPipelineExecutionLog).filter(
-            DocumentPipelineExecutionLog.document_id == document.id).first()
+        document_pipeline_excution_log = (
+            db.session.query(DocumentPipelineExecutionLog)
+            .filter(DocumentPipelineExecutionLog.document_id == document.id)
+            .first()
+        )
         if not document_pipeline_excution_log:
             raise ValueError("Document pipeline execution log not found")
         pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
@@ -52,19 +52,21 @@ def priority_rag_pipeline_run_task(

     try:
         start_at = time.perf_counter()
-        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(rag_pipeline_invoke_entities_file_id)
+        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
+            rag_pipeline_invoke_entities_file_id
+        )
         rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

         # Get Flask app object for thread context
         flask_app = current_app._get_current_object()  # type: ignore

         with ThreadPoolExecutor(max_workers=10) as executor:
             futures = []
             for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                 # Submit task to thread pool with Flask app
                 future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                 futures.append(future)

             # Wait for all tasks to complete
             for future in futures:
                 try:
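
The reason flask_app is handed to each submitted task: threads created by ThreadPoolExecutor do not inherit the Flask application context, so each worker has to push its own before touching context-bound objects. A minimal sketch of the pattern (the logging body is illustrative; Dify's worker does real pipeline work):

    from concurrent.futures import ThreadPoolExecutor

    from flask import Flask, current_app

    app = Flask(__name__)

    def run_one(item, flask_app):
        # Each worker thread pushes its own app context before using
        # current_app, database sessions, or other context-bound state.
        with flask_app.app_context():
            current_app.logger.info("processing %s", item)

    with app.app_context():
        flask_app = current_app._get_current_object()  # unwrap the proxy
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(run_one, i, flask_app) for i in range(3)]
            for future in futures:
                future.result()  # re-raises any exception from the worker
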
@@ -73,7 +75,9 @@ def priority_rag_pipeline_run_task(
                     logging.exception("Error in pipeline task")
         end_at = time.perf_counter()
         logging.info(
-            click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green")
+            click.style(
+                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
+            )
         )
     except Exception:
         logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
@@ -83,6 +87,7 @@ def priority_rag_pipeline_run_task(
         file_service.delete_file(rag_pipeline_invoke_entities_file_id)
         db.session.close()

+
 def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
     """Run a single RAG pipeline task within Flask app context."""
     # Create Flask application context for this thread
@@ -97,13 +102,13 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
         workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
         workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
         application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

         with Session(db.engine) as session:
             # Load required entities
             account = session.query(Account).filter(Account.id == user_id).first()
             if not account:
                 raise ValueError(f"Account {user_id} not found")

             tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
             if not tenant:
                 raise ValueError(f"Tenant {tenant_id} not found")
@@ -54,7 +54,8 @@ def rag_pipeline_run_task(
     try:
         start_at = time.perf_counter()
         rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
-            rag_pipeline_invoke_entities_file_id)
+            rag_pipeline_invoke_entities_file_id
+        )
         rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

         # Get Flask app object for thread context
@@ -75,8 +76,9 @@ def rag_pipeline_run_task(
                     logging.exception("Error in pipeline task")
         end_at = time.perf_counter()
         logging.info(
-            click.style(f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s",
-                        fg="green")
+            click.style(
+                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
+            )
         )
     except Exception:
         logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
@@ -94,8 +96,9 @@ def rag_pipeline_run_task(
             # Keep the flag set to indicate a task is running
             redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
             rag_pipeline_run_task.delay(  # type: ignore
-                rag_pipeline_invoke_entities_file_id=next_file_id.decode('utf-8') if isinstance(next_file_id,
-                                                                                                bytes) else next_file_id,
+                rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
+                if isinstance(next_file_id, bytes)
+                else next_file_id,
                 tenant_id=tenant_id,
             )
         else:
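
The decode branch above exists because redis-py returns bytes unless the client is created with decode_responses=True. The same guard expressed as a tiny hypothetical helper:

    def ensure_str(value: str | bytes) -> str:
        # IDs are stored as text, so UTF-8 decoding is safe here.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    assert ensure_str(b"file-123") == "file-123"
    assert ensure_str("file-123") == "file-123"
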
@@ -120,13 +123,13 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
         workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
         workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
         application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

         with Session(db.engine) as session:
             # Load required entities
             account = session.query(Account).filter(Account.id == user_id).first()
             if not account:
                 raise ValueError(f"Account {user_id} not found")

             tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
             if not tenant:
                 raise ValueError(f"Tenant {tenant_id} not found")