fix missing user_id

jyong 2025-09-11 17:36:49 +08:00
parent 914ae3c5d2
commit 32cccbbf88
5 changed files with 76 additions and 47 deletions


@@ -145,7 +145,7 @@ class DatasetDocumentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, dataset_id):
+    def get(self, dataset_id: str):
         dataset_id = str(dataset_id)
         page = request.args.get("page", default=1, type=int)
         limit = request.args.get("limit", default=20, type=int)
@@ -153,7 +153,7 @@ class DatasetDocumentListApi(Resource):
         sort = request.args.get("sort", default="-created_at", type=str)
         # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
         try:
-            fetch_val = request.args.get("fetch", default="false")
+            fetch_val = request.args.get("fetch", default=False)
             if isinstance(fetch_val, bool):
                 fetch = fetch_val
             else:
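
The hunk above cuts off inside the else branch, but the inline comment documents the intended coercion. A minimal, self-contained sketch of that parsing logic (the else-branch body here is inferred from the comment, not copied from the truncated code):

def parse_fetch(fetch_val: bool | str | None) -> bool:
    # With default=False, a missing ?fetch= query arg is already a bool and
    # short-circuits here instead of going through string parsing.
    if isinstance(fetch_val, bool):
        return fetch_val
    # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
    return str(fetch_val).lower() in ("yes", "true", "t", "y", "1")
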
@@ -250,7 +250,7 @@ class DatasetDocumentListApi(Resource):
     @marshal_with(dataset_and_document_fields)
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
-    def post(self, dataset_id):
+    def post(self, dataset_id: str):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -308,7 +308,7 @@ class DatasetDocumentListApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
-    def delete(self, dataset_id):
+    def delete(self, dataset_id: str):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:


@@ -119,9 +119,6 @@ class DraftRagPipelineApi(Resource):
         else:
             abort(415)

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         try:
             environment_variables_list = args.get("environment_variables") or []
             environment_variables = [
@@ -164,9 +161,6 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
@@ -201,9 +195,6 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         args = parser.parse_args()
@@ -238,9 +229,6 @@ class DraftRagPipelineRunApi(Resource):
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("datasource_type", type=str, required=True, location="json")
@@ -275,9 +263,6 @@ class PublishedRagPipelineRunApi(Resource):
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("datasource_type", type=str, required=True, location="json")
@@ -396,10 +381,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("datasource_type", type=str, required=True, location="json")


@@ -200,7 +200,7 @@ class DatasetListApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
         }
     )
-    def get(self, tenant_id):
+    def get(self, tenant_id: str):
         """Resource for getting datasets."""
         page = request.args.get("page", default=1, type=int)
         limit = request.args.get("limit", default=20, type=int)


@@ -47,3 +47,8 @@ class DatasetInUseError(BaseHTTPException):
     error_code = "dataset_in_use"
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
     code = 409
+
+class PipelineRunError(BaseHTTPException):
+    error_code = "pipeline_run_error"
+    description = "Pipeline run error."
+    code = 400
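
PipelineRunError follows the same shape as the errors above it: BaseHTTPException subclasses are raised from controllers and serialized by the API error handler into an error_code/description/status payload. A hypothetical call site, not part of this diff (the import path is assumed, since the file name is not shown here):

from controllers.service_api.dataset.error import PipelineRunError  # assumed path

def ensure_pipeline_succeeded(run_result: dict) -> None:
    # run_result is a stand-in dict for this sketch
    if run_result.get("status") == "failed":
        raise PipelineRunError()  # -> HTTP 400, error_code "pipeline_run_error"
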


@@ -33,6 +33,7 @@ from core.rag.entities.event import (
     DatasourceErrorEvent,
     DatasourceProcessingEvent,
 )
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.repositories.factory import DifyCoreRepositoryFactory
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables.variables import Variable
@@ -54,7 +55,13 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import Document, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin  # type: ignore
+from models.dataset import (  # type: ignore
+    Dataset,
+    Document,
+    Pipeline,
+    PipelineCustomizedTemplate,
+    PipelineRecommendedPlugin,
+)
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
@@ -480,7 +487,7 @@ class RagPipelineService:
         self,
         pipeline: Pipeline,
         node_id: str,
-        user_inputs: dict,
+        user_inputs: Mapping[str, Any],
         account: Account,
         datasource_type: str,
         is_published: bool,
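
Widening user_inputs from dict to Mapping[str, Any] is the usual typing move for read-only parameters: the callee signals it will not mutate the argument, and callers may pass any mapping, not just a plain dict. An illustration (the function name is illustrative):

from collections.abc import Mapping
from types import MappingProxyType
from typing import Any

def run_datasource_node(user_inputs: Mapping[str, Any]) -> None:
    # Reads are fine; user_inputs["k"] = v would be rejected by the type checker.
    _ = user_inputs.get("query")

run_datasource_node({"query": "hello"})                    # a plain dict still works
run_datasource_node(MappingProxyType({"query": "hello"}))  # read-only views are now accepted too
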
@@ -1312,7 +1319,7 @@ class RagPipelineService:
             "uninstalled_recommended_plugins": uninstalled_plugin_list,
         }

-    def get_datasource_plugins(self, dataset_id: str, is_published: bool) -> list[dict]:
+    def get_datasource_plugins(self, tenant_id: str, dataset_id: str, is_published: bool) -> list[dict]:
         """
         Get datasource plugins
         """
@@ -1325,9 +1332,9 @@ class RagPipelineService:
         workflow: Workflow | None = None
         if is_published:
-            workflow: Workflow | None = self.get_published_workflow(pipeline=pipeline)
+            workflow = self.get_published_workflow(pipeline=pipeline)
         else:
-            workflow: Workflow | None = self.get_draft_workflow(pipeline=pipeline)
+            workflow = self.get_draft_workflow(pipeline=pipeline)

         if not pipeline or not workflow:
             raise ValueError("Pipeline or workflow not found")
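
Both branches previously re-annotated workflow after its initial "workflow: Workflow | None = None" declaration; mypy treats a second annotated assignment to the same name as a redefinition (the [no-redef] error), so the fix keeps a single annotation and plain assignments. The rule in miniature:

def pick(is_published: bool) -> str | None:
    workflow: str | None = None  # annotate the name exactly once
    if is_published:
        workflow = "published"   # plain assignment afterwards
    else:
        workflow = "draft"       # re-annotating workflow here would trip no-redef
    return workflow
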
@@ -1338,33 +1345,68 @@ class RagPipelineService:
             datasource_node_data = datasource_node.get("data", {})
             if not datasource_node_data:
                 continue
-            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
-            user_input_variables_keys = []
-            user_input_variables = []
-            for _, value in datasource_parameters.items():
-                if value.get("value") and isinstance(value.get("value"), str):
-                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
-                    match = re.match(pattern, value["value"])
-                    if match:
-                        full_path = match.group(1)
-                        last_part = full_path.split(".")[-1]
-                elif value.get("value") and isinstance(value.get("value"), list):
-                    last_part = value.get("value")[-1]
-                    user_input_variables_keys.append(last_part)
-            for key, value in variables_map.items():
-                if key in user_input_variables_keys:
-                    user_input_variables.append(value)
+            variables = workflow.rag_pipeline_variables
+            if variables:
+                variables_map = {item["variable"]: item for item in variables}
+            else:
+                variables_map = {}
+            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+            user_input_variables_keys = []
+            user_input_variables = []
+            for _, value in datasource_parameters.items():
+                if value.get("value") and isinstance(value.get("value"), str):
+                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                    match = re.match(pattern, value["value"])
+                    if match:
+                        full_path = match.group(1)
+                        last_part = full_path.split(".")[-1]
+                        user_input_variables_keys.append(last_part)
+                elif value.get("value") and isinstance(value.get("value"), list):
+                    last_part = value.get("value")[-1]
+                    user_input_variables_keys.append(last_part)
+            for key, value in variables_map.items():
+                if key in user_input_variables_keys:
+                    user_input_variables.append(value)
+            # get credentials
+            datasource_provider_service: DatasourceProviderService = DatasourceProviderService()
+            credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials(
+                tenant_id=tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            credential_info_list: list[Any] = []
+            for credential in credentials:
+                credential_info_list.append({
+                    "id": credential.get("id"),
+                    "name": credential.get("name"),
+                    "type": credential.get("type"),
+                    "is_default": credential.get("is_default"),
+                })
             datasource_plugins.append({
                 "node_id": datasource_node.get("id"),
                 "plugin_id": datasource_node_data.get("plugin_id"),
                 "provider_name": datasource_node_data.get("provider_name"),
-                "provider_type": datasource_node_data.get("provider_type"),
+                "datasource_type": datasource_node_data.get("provider_type"),
                 "title": datasource_node_data.get("title"),
                 "user_input_variables": user_input_variables,
+                "credentials": credential_info_list,
             })
         return datasource_plugins
+
+    def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline:
+        """
+        Get pipeline
+        """
+        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        if not dataset:
+            raise ValueError("Dataset not found")
+        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        if not pipeline:
+            raise ValueError("Pipeline not found")
+        return pipeline
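
The regex in get_datasource_plugins pulls {{#node_id.variable#}} references out of datasource parameter values and keeps only the final dotted segment as the user-input key, which is then matched against variables_map. A standalone demo of that extraction (the sample value is hypothetical):

import re

# Same pattern as above: {{#id.segment(.segment...)#}} with 1-10 dotted parts.
PATTERN = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"

value = "{{#1718354758148.file_url#}}"  # hypothetical parameter value
match = re.match(PATTERN, value)        # re.match anchors at the start of the string
if match:
    full_path = match.group(1)          # "1718354758148.file_url"
    last_part = full_path.split(".")[-1]
    print(last_part)                    # prints "file_url", the key looked up in variables_map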