mirror of https://github.com/langgenius/dify.git
fix user_id missed
This commit is contained in:
parent
914ae3c5d2
commit
32cccbbf88
|
|
@ -145,7 +145,7 @@ class DatasetDocumentListApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
def get(self, dataset_id: str):
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
|
|
@ -153,7 +153,7 @@ class DatasetDocumentListApi(Resource):
|
|||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
fetch_val = request.args.get("fetch", default=False)
|
||||
if isinstance(fetch_val, bool):
|
||||
fetch = fetch_val
|
||||
else:
|
||||
|
|
@ -250,7 +250,7 @@ class DatasetDocumentListApi(Resource):
|
|||
@marshal_with(dataset_and_document_fields)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
def post(self, dataset_id: str):
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -308,7 +308,7 @@ class DatasetDocumentListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id):
|
||||
def delete(self, dataset_id: str):
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
|
|
|
|||
|
|
@ -119,9 +119,6 @@ class DraftRagPipelineApi(Resource):
|
|||
else:
|
||||
abort(415)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
|
|
@ -164,9 +161,6 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -201,9 +195,6 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -238,9 +229,6 @@ class DraftRagPipelineRunApi(Resource):
|
|||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
|
|
@ -275,9 +263,6 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
|
|
@ -396,10 +381,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class DatasetListApi(DatasetApiResource):
|
|||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id):
|
||||
def get(self, tenant_id: str):
|
||||
"""Resource for getting datasets."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
|
|
|
|||
|
|
@ -47,3 +47,8 @@ class DatasetInUseError(BaseHTTPException):
|
|||
error_code = "dataset_in_use"
|
||||
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
||||
code = 409
|
||||
|
||||
class PipelineRunError(BaseHTTPException):
|
||||
error_code = "pipeline_run_error"
|
||||
description = "Pipeline run error."
|
||||
code = 400
|
||||
|
|
@ -33,6 +33,7 @@ from core.rag.entities.event import (
|
|||
DatasourceErrorEvent,
|
||||
DatasourceProcessingEvent,
|
||||
)
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables.variables import Variable
|
||||
|
|
@ -54,7 +55,13 @@ from core.workflow.workflow_entry import WorkflowEntry
|
|||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.account import Account
|
||||
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin # type: ignore
|
||||
from models.dataset import ( # type: ignore
|
||||
Dataset,
|
||||
Document,
|
||||
Pipeline,
|
||||
PipelineCustomizedTemplate,
|
||||
PipelineRecommendedPlugin,
|
||||
)
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
|
|
@ -480,7 +487,7 @@ class RagPipelineService:
|
|||
self,
|
||||
pipeline: Pipeline,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: Mapping[str, Any],
|
||||
account: Account,
|
||||
datasource_type: str,
|
||||
is_published: bool,
|
||||
|
|
@ -1312,7 +1319,7 @@ class RagPipelineService:
|
|||
"uninstalled_recommended_plugins": uninstalled_plugin_list,
|
||||
}
|
||||
|
||||
def get_datasource_plugins(self, dataset_id: str, is_published: bool) -> list[dict]:
|
||||
def get_datasource_plugins(self, tenant_id: str, dataset_id: str, is_published: bool) -> list[dict]:
|
||||
"""
|
||||
Get datasource plugins
|
||||
"""
|
||||
|
|
@ -1325,9 +1332,9 @@ class RagPipelineService:
|
|||
|
||||
workflow: Workflow | None = None
|
||||
if is_published:
|
||||
workflow: Workflow | None = self.get_published_workflow(pipeline=pipeline)
|
||||
workflow = self.get_published_workflow(pipeline=pipeline)
|
||||
else:
|
||||
workflow: Workflow | None = self.get_draft_workflow(pipeline=pipeline)
|
||||
workflow = self.get_draft_workflow(pipeline=pipeline)
|
||||
if not pipeline or not workflow:
|
||||
raise ValueError("Pipeline or workflow not found")
|
||||
|
||||
|
|
@ -1338,33 +1345,68 @@ class RagPipelineService:
|
|||
datasource_node_data = datasource_node.get("data", {})
|
||||
if not datasource_node_data:
|
||||
continue
|
||||
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
user_input_variables_keys = []
|
||||
user_input_variables = []
|
||||
|
||||
for _, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
variables = workflow.rag_pipeline_variables
|
||||
if variables:
|
||||
variables_map = {item["variable"]: item for item in variables}
|
||||
else:
|
||||
variables_map = {}
|
||||
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
user_input_variables_keys = []
|
||||
user_input_variables = []
|
||||
|
||||
for _, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
user_input_variables_keys.append(last_part)
|
||||
elif value.get("value") and isinstance(value.get("value"), list):
|
||||
last_part = value.get("value")[-1]
|
||||
user_input_variables_keys.append(last_part)
|
||||
elif value.get("value") and isinstance(value.get("value"), list):
|
||||
last_part = value.get("value")[-1]
|
||||
user_input_variables_keys.append(last_part)
|
||||
for key, value in variables_map.items():
|
||||
if key in user_input_variables_keys:
|
||||
user_input_variables.append(value)
|
||||
for key, value in variables_map.items():
|
||||
if key in user_input_variables_keys:
|
||||
user_input_variables.append(value)
|
||||
|
||||
# get credentials
|
||||
datasource_provider_service: DatasourceProviderService = DatasourceProviderService()
|
||||
credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_node_data.get("provider_name"),
|
||||
plugin_id=datasource_node_data.get("plugin_id")
|
||||
)
|
||||
credential_info_list: list[Any] = []
|
||||
for credential in credentials:
|
||||
credential_info_list.append({
|
||||
"id": credential.get("id"),
|
||||
"name": credential.get("name"),
|
||||
"type": credential.get("type"),
|
||||
"is_default": credential.get("is_default"),
|
||||
})
|
||||
|
||||
datasource_plugins.append({
|
||||
"node_id": datasource_node.get("id"),
|
||||
"plugin_id": datasource_node_data.get("plugin_id"),
|
||||
"provider_name": datasource_node_data.get("provider_name"),
|
||||
"provider_type": datasource_node_data.get("provider_type"),
|
||||
"datasource_type": datasource_node_data.get("provider_type"),
|
||||
"title": datasource_node_data.get("title"),
|
||||
"user_input_variables": user_input_variables,
|
||||
"credentials": credential_info_list,
|
||||
})
|
||||
|
||||
return datasource_plugins
|
||||
|
||||
def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline:
|
||||
"""
|
||||
Get pipeline
|
||||
"""
|
||||
dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
return pipeline
|
||||
|
|
|
|||
Loading…
Reference in New Issue