This commit is contained in:
jyong 2025-07-02 18:15:23 +08:00
parent e23d7e39ec
commit 81b07dc3be
14 changed files with 102 additions and 32 deletions

View File

@@ -1051,11 +1051,12 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
             .first()
         )
         if not log:
-            return {"datasource_info": None,
-                    "datasource_type": None,
-                    "input_data": None,
-                    "datasource_node_id": None,
-                    }, 200
+            return {
+                "datasource_info": None,
+                "datasource_type": None,
+                "input_data": None,
+                "datasource_node_id": None,
+            }, 200
         return {
             "datasource_info": json.loads(log.datasource_info),
             "datasource_type": log.datasource_type,
@@ -1086,5 +1087,6 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
 api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
 api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
-api.add_resource(DocumentPipelineExecutionLogApi,
-                 "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log")
+api.add_resource(
+    DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
+)

View File

@@ -96,7 +96,7 @@ class DatasourceAuth(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
         parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
         parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
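Note on the changed line: with required=False plus a default, a missing "name" now falls back to "test" instead of failing validation, while nullable=False still rejects an explicit null. A minimal, self-contained sketch of that reqparse behavior (hypothetical endpoint, not part of this commit):

from flask import Flask
from flask_restful import reqparse

app = Flask(__name__)

parser = reqparse.RequestParser()
# required=False + default: a missing "name" key falls back to "test"
parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")

with app.test_request_context("/", method="POST", json={}):
    print(parser.parse_args()["name"])  # -> "test"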

View File

@@ -48,7 +48,8 @@ class DataSourceContentPreviewApi(Resource):
         )
         return preview_content, 200

+
 api.add_resource(
     DataSourceContentPreviewApi,
-    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview"
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
 )

View File

@@ -1,4 +1,3 @@
-from ast import Str
 from collections.abc import Sequence
 from enum import Enum, StrEnum
 from typing import Any, Literal, Optional
@@ -128,14 +127,17 @@ class VariableEntity(BaseModel):
     def convert_none_options(cls, v: Any) -> Sequence[str]:
         return v or []


+class RagPipelineVariableEntity(VariableEntity):
+    """
+    Rag Pipeline Variable Entity.
+    """
+
+    tooltips: Optional[str] = None
+    placeholder: Optional[str] = None
+    belong_to_node_id: str
+
+
 class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.
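The new RagPipelineVariableEntity in isolation — a runnable sketch in which VariableEntity is a simplified stand-in for the real base class in core.app.app_config.entities:

from typing import Optional

from pydantic import BaseModel


class VariableEntity(BaseModel):
    """Simplified stand-in for the real base class."""

    variable: str = ""


class RagPipelineVariableEntity(VariableEntity):
    """
    Rag Pipeline Variable Entity.
    """

    tooltips: Optional[str] = None
    placeholder: Optional[str] = None
    belong_to_node_id: str


v = RagPipelineVariableEntity(variable="start_url", belong_to_node_id="shared")
print(v.model_dump())  # tooltips and placeholder default to None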

View File

@@ -1,5 +1,3 @@
-from typing import Any
-
 from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
 from models.workflow import Workflow

View File

@@ -13,6 +13,7 @@ class PipelineConfig(WorkflowUIBasedAppConfig):
     """
     Pipeline Config Entity.
     """

+    rag_pipeline_variables: list[RagPipelineVariableEntity] = []
     pass
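The `= []` default is safe on a Pydantic model: defaults are copied per instance, unlike plain class attributes or dataclasses without default_factory (the surviving `pass` is now redundant). A quick sketch:

from pydantic import BaseModel


class PipelineConfig(BaseModel):
    rag_pipeline_variables: list[str] = []


a = PipelineConfig()
b = PipelineConfig()
a.rag_pipeline_variables.append("x")
print(b.rag_pipeline_variables)  # [] — the default list is not shared across instances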

View File

@@ -47,6 +47,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
     def _get_app_id(self) -> str:
         return self.application_generate_entity.app_config.app_id

     def run(self) -> None:
         """
         Run application
@@ -114,9 +115,9 @@ class PipelineRunner(WorkflowBasedAppRunner):
         for v in workflow.rag_pipeline_variables:
             rag_pipeline_variable = RAGPipelineVariable(**v)
             if (
-                (rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id or rag_pipeline_variable.belong_to_node_id == "shared")
-                and rag_pipeline_variable.variable in inputs
-            ):
+                rag_pipeline_variable.belong_to_node_id
+                in (self.application_generate_entity.start_node_id, "shared")
+            ) and rag_pipeline_variable.variable in inputs:
                 rag_pipeline_variables.append(
                     RAGPipelineVariableInput(
                         variable=rag_pipeline_variable,
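The rewritten condition uses the standard membership-test idiom; it is equivalent to the chained `or` comparisons it replaces:

start_node_id = "datasource_node"
for belong_to in ("datasource_node", "shared", "other_node"):
    old_style = belong_to == start_node_id or belong_to == "shared"
    new_style = belong_to in (start_node_id, "shared")
    assert old_style == new_style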

View File

@@ -10,8 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable
 from core.variables.consts import MIN_SELECTORS_LENGTH
 from core.variables.segments import FileSegment, NoneSegment
 from core.variables.variables import RAGPipelineVariableInput
-from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, \
-    SYSTEM_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID
+from core.workflow.constants import (
+    CONVERSATION_VARIABLE_NODE_ID,
+    ENVIRONMENT_VARIABLE_NODE_ID,
+    RAG_PIPELINE_VARIABLE_NODE_ID,
+    SYSTEM_VARIABLE_NODE_ID,
+)
 from core.workflow.enums import SystemVariableKey
 from factories import variable_factory

View File

@@ -462,6 +462,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 inputs=parameters_for_log,
             )
         )

+    @classmethod
     def version(cls) -> str:
         return "1"

View File

@@ -323,13 +323,11 @@ class Workflow(Base):
         return variables

     def rag_pipeline_user_input_form(self) -> list:
         # get user_input_form from start node
         variables: list[Any] = self.rag_pipeline_variables
         return variables

     @property
     def unique_hash(self) -> str:
         """
View File

@@ -344,10 +344,10 @@ class DatasetService:
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise ValueError("Dataset not found")
-        # check if dataset name is exists
+        # check if dataset name is exists
         if (
             db.session.query(Dataset)
-            .filter(
+            .filter(
                 Dataset.id != dataset_id,
                 Dataset.name == data.get("name", dataset.name),
                 Dataset.tenant_id == dataset.tenant_id,
@@ -470,7 +470,7 @@ class DatasetService:
         filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         # update Retrieval model
         filtered_data["retrieval_model"] = data["retrieval_model"]
-        # update icon info
+        # update icon info
         if data.get("icon_info"):
             filtered_data["icon_info"] = data.get("icon_info")

View File

@@ -32,14 +32,10 @@ class DatasourceProviderService:
         :param credentials:
         """
         # check name is exist
-        datasource_provider = (
-            db.session.query(DatasourceProvider)
-            .filter_by(tenant_id=tenant_id, name=name)
-            .first()
-        )
+        datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first()
         if datasource_provider:
             raise ValueError("Authorization name is already exists")

         credential_valid = self.provider_manager.validate_provider_credentials(
             tenant_id=tenant_id,
             user_id=current_user.id,
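The collapsed query uses filter_by, which takes keyword equality tests against the queried entity, whereas filter (as in the DatasetService hunk above) takes full SQL expressions such as `Dataset.id != dataset_id`. A self-contained sketch with a hypothetical mapped class:

from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class DatasourceProvider(Base):  # hypothetical standalone model, for illustration only
    __tablename__ = "datasource_providers"
    id = Column(String, primary_key=True)
    tenant_id = Column(String)
    name = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # filter_by: keyword equality on the queried entity
    exists = session.query(DatasourceProvider).filter_by(tenant_id="t1", name="github").first()
    # filter: arbitrary SQL expressions, needed for !=, <, LIKE, etc.
    others = session.query(DatasourceProvider).filter(DatasourceProvider.name != "github").count()
    print(exists, others)  # None 0 on the empty database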

View File

@@ -20,9 +20,12 @@ from core.datasource.entities.datasource_entities import (
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
     OnlineDocumentPagesMessage,
+    OnlineDriveBrowseFilesRequest,
+    OnlineDriveBrowseFilesResponse,
     WebsiteCrawlMessage,
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
 from core.rag.entities.event import (
     BaseDatasourceEvent,
@@ -31,8 +34,9 @@ from core.rag.entities.event import (
     DatasourceProcessingEvent,
 )
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.variables.variables import Variable
+from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput, Variable
 from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import (
     WorkflowNodeExecution,
     WorkflowNodeExecutionStatus,
@@ -381,6 +385,17 @@ class RagPipelineService:
         # run draft workflow node
         start_at = time.perf_counter()

+        rag_pipeline_variables = []
+        if draft_workflow.rag_pipeline_variables:
+            for v in draft_workflow.rag_pipeline_variables:
+                rag_pipeline_variable = RAGPipelineVariable(**v)
+                if rag_pipeline_variable.variable in user_inputs:
+                    rag_pipeline_variables.append(
+                        RAGPipelineVariableInput(
+                            variable=rag_pipeline_variable,
+                            value=user_inputs[rag_pipeline_variable.variable],
+                        )
+                    )
         workflow_node_execution = self._handle_node_run_result(
             getter=lambda: WorkflowEntry.single_step_run(
@@ -388,6 +403,12 @@ class RagPipelineService:
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
+                variable_pool=VariablePool(
+                    user_inputs=user_inputs,
+                    environment_variables=draft_workflow.environment_variables,
+                    conversation_variables=draft_workflow.conversation_variables,
+                    rag_pipeline_variables=rag_pipeline_variables,
+                ),
             ),
             start_at=start_at,
             tenant_id=pipeline.tenant_id,
@@ -413,6 +434,17 @@ class RagPipelineService:
         # run draft workflow node
         start_at = time.perf_counter()

+        rag_pipeline_variables = []
+        if published_workflow.rag_pipeline_variables:
+            for v in published_workflow.rag_pipeline_variables:
+                rag_pipeline_variable = RAGPipelineVariable(**v)
+                if rag_pipeline_variable.variable in user_inputs:
+                    rag_pipeline_variables.append(
+                        RAGPipelineVariableInput(
+                            variable=rag_pipeline_variable,
+                            value=user_inputs[rag_pipeline_variable.variable],
+                        )
+                    )
         workflow_node_execution = self._handle_node_run_result(
             getter=lambda: WorkflowEntry.single_step_run(
@@ -420,6 +452,12 @@ class RagPipelineService:
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
+                variable_pool=VariablePool(
+                    user_inputs=user_inputs,
+                    environment_variables=published_workflow.environment_variables,
+                    conversation_variables=published_workflow.conversation_variables,
+                    rag_pipeline_variables=rag_pipeline_variables,
+                ),
             ),
             start_at=start_at,
             tenant_id=pipeline.tenant_id,
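The draft and published branches now run an identical matching loop. A hedged sketch of how it could be factored into a shared helper (hypothetical — not in this commit; the two models are simplified stand-ins for core.variables.variables):

from collections.abc import Mapping, Sequence
from typing import Any

from pydantic import BaseModel


class RAGPipelineVariable(BaseModel):  # simplified stand-in
    variable: str
    belong_to_node_id: str = "shared"


class RAGPipelineVariableInput(BaseModel):  # simplified stand-in
    variable: RAGPipelineVariable
    value: Any


def build_rag_pipeline_variable_inputs(
    raw_variables: Sequence[Mapping[str, Any]], user_inputs: Mapping[str, Any]
) -> list[RAGPipelineVariableInput]:
    inputs: list[RAGPipelineVariableInput] = []
    for raw in raw_variables:
        variable = RAGPipelineVariable(**raw)
        # keep only declared variables the caller actually supplied a value for
        if variable.variable in user_inputs:
            inputs.append(RAGPipelineVariableInput(variable=variable, value=user_inputs[variable.variable]))
    return inputs


print(build_rag_pipeline_variable_inputs([{"variable": "url"}, {"variable": "depth"}], {"url": "https://x.dev"}))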
@@ -511,6 +549,33 @@ class RagPipelineService:
                 except Exception as e:
                     logger.exception("Error during online document.")
                     yield DatasourceErrorEvent(error=str(e)).model_dump()
+            case DatasourceProviderType.ONLINE_DRIVE:
+                datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
+                online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = datasource_runtime.online_drive_browse_files(
+                    user_id=account.id,
+                    request=OnlineDriveBrowseFilesRequest(
+                        bucket=user_inputs.get("bucket"),
+                        prefix=user_inputs.get("prefix"),
+                        max_keys=user_inputs.get("max_keys", 20),
+                        start_after=user_inputs.get("start_after"),
+                    ),
+                    provider_type=datasource_runtime.datasource_provider_type(),
+                )
+                start_time = time.time()
+                start_event = DatasourceProcessingEvent(
+                    total=0,
+                    completed=0,
+                )
+                yield start_event.model_dump()
+                for message in online_drive_result:
+                    end_time = time.time()
+                    online_drive_event = DatasourceCompletedEvent(
+                        data=message.result,
+                        time_consuming=round(end_time - start_time, 2),
+                        total=None,
+                        completed=None,
+                    )
+                    yield online_drive_event.model_dump()
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                 website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
@@ -631,7 +696,7 @@ class RagPipelineService:
                 except Exception as e:
                     logger.exception("Error during get online document content.")
                     raise RuntimeError(str(e))
-            #TODO Online Drive
+            # TODO Online Drive
             case _:
                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
         except Exception as e:
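The new ONLINE_DRIVE branch follows the same streaming shape as the other datasource types: emit a processing event first, then one completed event per page yielded by the plugin generator. The pattern in a self-contained sketch (event classes and the file-browsing generator are simplified stand-ins for core.rag.entities.event and the drive plugin):

import time
from collections.abc import Generator
from typing import Any, Optional

from pydantic import BaseModel


class DatasourceProcessingEvent(BaseModel):  # simplified stand-in
    total: int = 0
    completed: int = 0


class DatasourceCompletedEvent(BaseModel):  # simplified stand-in
    data: Any
    time_consuming: float
    total: Optional[int] = None
    completed: Optional[int] = None


def browse_files() -> Generator[list[str], None, None]:
    yield ["a.txt", "b.txt"]  # one page of drive entries per yield


def stream_events() -> Generator[dict, None, None]:
    start_time = time.time()
    yield DatasourceProcessingEvent().model_dump()  # signal that work has started
    for page in browse_files():
        yield DatasourceCompletedEvent(data=page, time_consuming=round(time.time() - start_time, 2)).model_dump()


for event in stream_events():
    print(event)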

View File

@@ -86,8 +86,9 @@ class ToolTransformService:
                 )
             else:
                 provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
-                    provider_type=provider.type.value, provider_name=provider.name,
-                    icon=provider.declaration.identity.icon
+                    provider_type=provider.type.value,
+                    provider_name=provider.name,
+                    icon=provider.declaration.identity.icon,
                 )

     @classmethod