mirror of https://github.com/langgenius/dify.git

commit 384073f025 ("r2 transform")
parent 2012ea3213

@@ -947,7 +947,8 @@ class RagPipelineWorkflowLastRunApi(Resource):
         if node_exec is None:
             raise NotFound("last run not found")
         return node_exec
 
+
 class RagPipelineTransformApi(Resource):
     @setup_required
     @login_required
@@ -955,8 +956,8 @@ class RagPipelineTransformApi(Resource):
     def post(self, dataset_id):
         dataset_id = str(dataset_id)
         rag_pipeline_transform_service = RagPipelineTransformService()
-        rag_pipeline_transform_service.transform_dataset(dataset_id)
-        return {"message": "success"}
+        result = rag_pipeline_transform_service.transform_dataset(dataset_id)
+        return result
 
 
 api.add_resource(
@@ -1070,4 +1071,4 @@ api.add_resource(
 api.add_resource(
     RagPipelineTransformApi,
     "/rag/pipelines/transform/datasets/<uuid:dataset_id>",
-)
+)
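
Note: with the controller change above, POST /rag/pipelines/transform/datasets/<uuid:dataset_id> now returns the service result instead of a fixed {"message": "success"}. A minimal client-side sketch, assuming the console API prefix /console/api, a local server, and an already-obtained access token (all illustrative, not part of this diff):

import requests

resp = requests.post(
    "http://localhost:5001/console/api/rag/pipelines/transform/datasets/<dataset-uuid>",
    headers={"Authorization": "Bearer <access-token>"},
)
# Per the service change further down, the body now looks like:
# {"pipeline_id": "...", "dataset_id": "...", "status": "success"}
print(resp.json())
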
@@ -1,4 +1,5 @@
 import re
+
 from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
 from models.workflow import Workflow
 
@@ -56,7 +57,7 @@ class WorkflowVariablesConfigManager:
             last_part = full_path.split(".")[-1]
             variables_map.pop(last_part)
         all_second_step_variables = list(variables_map.values())
-
+
         for item in all_second_step_variables:
             if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
                 variables.append(RagPipelineVariableEntity.model_validate(item))
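
Note: the loop above keeps only the second-step variables that belong to the start node or are marked "shared", then parses each raw dict through Pydantic. A standalone sketch of the same filter with made-up sample data (field names other than belong_to_node_id are illustrative):

start_node_id = "start"
all_second_step_variables = [
    {"belong_to_node_id": "start", "variable": "query"},
    {"belong_to_node_id": "other-node", "variable": "x"},
    {"belong_to_node_id": "shared", "variable": "lang"},
]
kept = [
    item
    for item in all_second_step_variables
    if item.get("belong_to_node_id") in (start_node_id, "shared")
]
# each kept dict is then validated into a model, as above:
# RagPipelineVariableEntity.model_validate(item)  # Pydantic v2 dict-to-model entry point
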
@@ -171,43 +171,45 @@ class DatasourceProviderService:
                             }
                             for option in credential.options or []
                         ],
-                    } for credential in datasource.declaration.credentials_schema
+                    }
+                    for credential in datasource.declaration.credentials_schema
                 ],
-                "oauth_schema":
-                {
-                    "client_schema": [
-                        {
-                            "type": client_schema.type.value,
-                            "name": client_schema.name,
-                            "required": client_schema.required,
-                            "default": client_schema.default,
-                            "options": [
-                                {
-                                    "value": option.value,
-                                    "label": option.label.model_dump(),
-                                }
-                                for option in client_schema.options or []
-                            ],
-                        }
-                        for client_schema in datasource.declaration.oauth_schema.client_schema or []
-                    ],
-                    "credentials_schema": [
-                        {
-                            "type": credential.type.value,
-                            "name": credential.name,
-                            "required": credential.required,
-                            "default": credential.default,
-                            "options": [
-                                {
-                                    "value": option.value,
-                                    "label": option.label.model_dump(),
-                                }
-                                for option in credential.options or []
-                            ],
-                        }
-                        for credential in datasource.declaration.oauth_schema.credentials_schema or []
-                    ],
-                } if datasource.declaration.oauth_schema else None,
+                "oauth_schema": {
+                    "client_schema": [
+                        {
+                            "type": client_schema.type.value,
+                            "name": client_schema.name,
+                            "required": client_schema.required,
+                            "default": client_schema.default,
+                            "options": [
+                                {
+                                    "value": option.value,
+                                    "label": option.label.model_dump(),
+                                }
+                                for option in client_schema.options or []
+                            ],
+                        }
+                        for client_schema in datasource.declaration.oauth_schema.client_schema or []
+                    ],
+                    "credentials_schema": [
+                        {
+                            "type": credential.type.value,
+                            "name": credential.name,
+                            "required": credential.required,
+                            "default": credential.default,
+                            "options": [
+                                {
+                                    "value": option.value,
+                                    "label": option.label.model_dump(),
+                                }
+                                for option in credential.options or []
+                            ],
+                        }
+                        for credential in datasource.declaration.oauth_schema.credentials_schema or []
+                    ],
+                }
+                if datasource.declaration.oauth_schema
+                else None,
             }
         )
     return datasource_credentials
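
Note: this hunk is a formatter-style reflow of the nested dict comprehension; the one-line tail "} if datasource.declaration.oauth_schema else None," is split over three lines and "oauth_schema": gains its brace on the same line. The serialized payload is unchanged. For orientation, a sketch of the shape one oauth_schema entry serializes to, with illustrative values (field names are taken from the code above):

oauth_schema_example = {
    "client_schema": [
        {
            "type": "secret-input",  # client_schema.type.value, illustrative
            "name": "client_id",     # illustrative
            "required": True,
            "default": None,
            "options": [{"value": "v1", "label": {"en_US": "V1"}}],
        }
    ],
    "credentials_schema": [],  # entries use the same per-field shape as client_schema
}
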
@@ -54,7 +54,7 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import Dataset, Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
+from models.dataset import Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
@@ -15,8 +15,6 @@ from services.entities.knowledge_entities.rag_pipeline_entities import Knowledge
 
 
 class RagPipelineTransformService:
-
-
     def transform_dataset(self, dataset_id: str):
         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if not dataset:
@@ -42,7 +40,10 @@ class RagPipelineTransformService:
         new_nodes = []
 
         for node in nodes:
-            if node.get("data", {}).get("type") == "datasource" and node.get("data", {}).get("provider_type") == "local_file":
+            if (
+                node.get("data", {}).get("type") == "datasource"
+                and node.get("data", {}).get("provider_type") == "local_file"
+            ):
                 node = self._deal_file_extensions(node)
             if node.get("data", {}).get("type") == "knowledge-index":
                 node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
@@ -66,6 +67,11 @@ class RagPipelineTransformService:
         dataset.pipeline_id = pipeline.id
 
         db.session.commit()
+        return {
+            "pipeline_id": pipeline.id,
+            "dataset_id": dataset_id,
+            "status": "success",
+        }
 
     def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str):
         if doc_form == "text_model":
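
Note: transform_dataset now reports what it created, and the controller hunk at the top of this commit returns that payload verbatim. A caller-side sketch of the new contract (hypothetical usage, not taken from the diff):

result = RagPipelineTransformService().transform_dataset(dataset_id)
assert result["status"] == "success"
new_pipeline_id = result["pipeline_id"]  # id of the pipeline the dataset was transformed into
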
@@ -73,29 +79,29 @@ class RagPipelineTransformService:
             case "upload_file":
                 if indexing_technique == "high_quality":
                     # get graph from transform.file-general-high-quality.yml
-                    with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
                 if indexing_technique == "economy":
                     # get graph from transform.file-general-economy.yml
-                    with open(f"{Path(__file__).parent}/transform/file-general-economy.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
             case "notion_import":
                 if indexing_technique == "high_quality":
                     # get graph from transform.notion-general-high-quality.yml
-                    with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
                 if indexing_technique == "economy":
                     # get graph from transform.notion-general-economy.yml
-                    with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
             case "website_crawl":
                 if indexing_technique == "high_quality":
                     # get graph from transform.website-crawl-general-high-quality.yml
-                    with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
                 if indexing_technique == "economy":
                     # get graph from transform.website-crawl-general-economy.yml
-                    with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
             case _:
                 raise ValueError("Unsupported datasource type")
@@ -103,15 +109,15 @@ class RagPipelineTransformService:
         match datasource_type:
             case "upload_file":
                 # get graph from transform.file-parent-child.yml
-                with open(f"{Path(__file__).parent}/transform/file-parent-child.yml", "r") as f:
+                with open(f"{Path(__file__).parent}/transform/file-parent-child.yml") as f:
                     pipeline_yaml = yaml.safe_load(f)
             case "notion_import":
                 # get graph from transform.notion-parent-child.yml
-                with open(f"{Path(__file__).parent}/transform/notion-parent-child.yml", "r") as f:
+                with open(f"{Path(__file__).parent}/transform/notion-parent-child.yml") as f:
                     pipeline_yaml = yaml.safe_load(f)
             case "website_crawl":
                 # get graph from transform.website-crawl-parent-child.yml
-                with open(f"{Path(__file__).parent}/transform/website-crawl-parent-child.yml", "r") as f:
+                with open(f"{Path(__file__).parent}/transform/website-crawl-parent-child.yml") as f:
                     pipeline_yaml = yaml.safe_load(f)
             case _:
                 raise ValueError("Unsupported datasource type")
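
Note: in the two hunks above, dropping the explicit "r" argument is behavior-neutral: text-read is already open()'s default mode (ruff flags the redundant argument as UP015). The repeated open/safe_load pattern could also be condensed into a table-driven helper; a hypothetical sketch, not part of the diff:

from pathlib import Path

import yaml

_PARENT_CHILD_TEMPLATES = {
    "upload_file": "file-parent-child.yml",
    "notion_import": "notion-parent-child.yml",
    "website_crawl": "website-crawl-parent-child.yml",
}

def load_parent_child_yaml(datasource_type: str) -> dict:
    # Resolve the bundled template for this datasource type, or fail loudly.
    name = _PARENT_CHILD_TEMPLATES.get(datasource_type)
    if name is None:
        raise ValueError("Unsupported datasource type")
    with open(Path(__file__).parent / "transform" / name) as f:  # mode "r" is the default
        return yaml.safe_load(f)
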
@@ -127,7 +133,9 @@ class RagPipelineTransformService:
             node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS
         return node
 
-    def _deal_knowledge_index(self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict):
+    def _deal_knowledge_index(
+        self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict
+    ):
         knowledge_configuration = node.get("data", {})
         knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
 
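
Note: only the signature is re-wrapped here; the body still unpacks node["data"] into the KnowledgeConfiguration entity, which validates the raw node data on construction. A stand-in sketch of that validate-by-unpacking pattern (FakeKnowledgeConfig is hypothetical, not the real entity):

from pydantic import BaseModel

class FakeKnowledgeConfig(BaseModel):
    chunk_structure: str = "text_model"  # illustrative field, not the real schema

node = {"data": {"chunk_structure": "text_model"}}
config = FakeKnowledgeConfig(**node.get("data", {}))  # raises ValidationError on bad input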