This commit is contained in:
jyong 2025-05-26 14:49:59 +08:00
parent ec1c4efca9
commit 665ffbdc10
11 changed files with 143 additions and 94 deletions

View File

@ -8,6 +8,7 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from core.plugin.impl.datasource import PluginDatasourceManager
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@ -109,7 +110,30 @@ class OAuthDataSourceSync(Resource):
return {"result": "success"}, 200
class DatasourcePluginOauthApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, datasource_type, datasource_name):
        # Only editor-level accounts may request a datasource plugin OAuth URL
        if not current_user.is_editor:
            raise Forbidden()
        # Delegate to the plugin daemon client; get_provider_oauth_url is added in this commit.
        manager = PluginDatasourceManager()
        oauth_url = manager.get_provider_oauth_url(
            datasource_type=datasource_type,
            datasource_name=datasource_name,
            provider=datasource_name,  # assumption: the provider id matches the datasource name
        )
        return {"oauth_url": oauth_url}, 200
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
api.add_resource(DatasourcePluginOauthApi, "/oauth/plugin/datasource/<string:datasource_type>/<string:datasource_name>")
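
For reference, a minimal sketch of how a client might exercise the new endpoint. The /console/api prefix, the bearer-token auth, and the path values are assumptions, not part of this diff; the handler above requires an editor-level account.

import requests

# Hypothetical host, token, and path values; adjust to the actual deployment.
resp = requests.get(
    "https://dify.example.com/console/api/oauth/plugin/datasource/online_document/notion",
    headers={"Authorization": "Bearer <console-access-token>"},
)
print(resp.json())  # e.g. {"oauth_url": "..."} with the handler sketched above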

View File

@ -280,6 +280,8 @@ class PublishedRagPipelineRunApi(Resource):
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
args = parser.parse_args()
try:
@ -287,7 +289,7 @@ class PublishedRagPipelineRunApi(Resource):
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.PUBLISHED,
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
streaming=True,
)
@ -469,6 +471,7 @@ class PublishedRagPipelineApi(Resource):
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
pipeline = session.merge(pipeline)
workflow = rag_pipeline_service.publish_workflow(
session=session,
pipeline=pipeline,
@ -478,6 +481,7 @@ class PublishedRagPipelineApi(Resource):
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
session.add(pipeline)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
@ -797,6 +801,10 @@ api.add_resource(
DraftRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
)
api.add_resource(
PublishedRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/run",
)
api.add_resource(
RagPipelineTaskStopApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",

View File

@ -92,7 +92,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
@ -108,23 +108,24 @@ class PipelineGenerator(BaseAppGenerator):
for datasource_info in datasource_info_list:
workflow_run_id = str(uuid.uuid4())
document_id = None
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
# Add null check for dataset
if not pipeline.dataset:
raise ValueError("Pipeline dataset is required")
if invoke_from == InvokeFrom.PUBLISHED:
position = DocumentService.get_documents_position(pipeline.dataset_id)
position = DocumentService.get_documents_position(pipeline.dataset_id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=pipeline.dataset_id,
built_in_field_enabled=dataset.built_in_field_enabled,
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=dataset.chunk_structure,
document_form=pipeline.dataset.chunk_structure,
)
db.session.add(document)
db.session.commit()
@ -136,7 +137,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
dataset_id=pipeline.dataset.id,
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
@ -274,27 +275,24 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# init application generate entity
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args["datasource_type"],
datasource_info=args["datasource_info"],
app_config=app_config,
pipeline_config=app_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=pipeline.dataset_id,
batch=args["batch"],
document_id=args["document_id"],
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
call_depth=0,
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

View File

@ -104,6 +104,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from,
}
variable_pool = VariablePool(

View File

@ -1,12 +1,11 @@
from collections.abc import Mapping
from typing import Any
from core.datasource.entities.api_entities import DatasourceProviderApiEntity
from core.datasource.entities.datasource_entities import (
GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse,
GetOnlineDocumentPagesResponse,
GetWebsiteCrawlResponse, DatasourceProviderEntity,
GetWebsiteCrawlResponse,
)
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import (
@ -228,7 +227,30 @@ class PluginDatasourceManager(BasePluginClient):
return resp.result
return False
def get_provider_oauth_url(self, datasource_type: str, datasource_name: str, provider: str) -> str:
    """
    Get the OAuth URL of the given datasource provider.
    """
    tool_provider_id = GenericProviderID(provider)
    response = self._request_with_plugin_daemon_response_stream(
        "GET",
        "plugin/datasource/oauth",
        PluginBasicBooleanResponse,
        params={"page": 1, "page_size": 256},
        headers={
            "X-Plugin-ID": tool_provider_id.plugin_id,
            "Content-Type": "application/json",
        },
    )
    for resp in response:
        return resp.result
    return ""
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
return {
"id": "langgenius/file/file",

View File

@ -20,3 +20,4 @@ class SystemVariableKey(StrEnum):
DATASET_ID = "dataset_id"
DATASOURCE_TYPE = "datasource_type"
DATASOURCE_INFO = "datasource_info"
INVOKE_FROM = "invoke_from"
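
A minimal sketch of how a node can branch on the new key, assuming a variable pool populated as in the pipeline runner change above; the knowledge-index change below performs the same check inline.

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import SystemVariableKey


def is_preview_run(variable_pool) -> bool:
    # DEBUGGER marks a preview invocation; PUBLISHED is a real ingestion run.
    return variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER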

View File

@ -17,7 +17,6 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
@ -33,7 +32,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
_node_data_cls = DatasourceNodeData
_node_type = NodeType.DATASOURCE
def _run(self) -> Generator:
def _run(self) -> NodeRunResult:
"""
Run the datasource node
"""
@ -58,21 +57,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=node_data.provider_id,
datasource_name=node_data.datasource_name,
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)
)
return
# get parameters
datasource_parameters = datasource_runtime.entity.parameters
@ -99,66 +96,55 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
provider_type=datasource_type,
)
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_type,
},
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_type,
},
)
case DatasourceProviderType.WEBSITE_CRAWL:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"website": datasource_info,
"datasource_type": datasource_type,
},
)
},
)
case DatasourceProviderType.LOCAL_FILE:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": datasource_info,
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
)
case _:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)
return
def _generate_parameters(
self,
@ -225,18 +211,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
:return:
"""
result = {}
for parameter_name in node_data.datasource_parameters:
input = node_data.datasource_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
input = node_data.datasource_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}
result = {node_id + "." + key: value for key, value in result.items()}
return result
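
For illustration, a standalone sketch of the mapping the guarded loop above produces. The parameter shapes and names are hypothetical stand-ins for the real entities, and the "mixed" branch (which needs VariableTemplateParser) is omitted.

from dataclasses import dataclass
from typing import Any


@dataclass
class FakeInput:  # hypothetical stand-in for the datasource parameter entity
    type: str
    value: Any


datasource_parameters = {
    "url": FakeInput(type="variable", value=["start", "url"]),  # direct variable reference
    "limit": FakeInput(type="constant", value=10),              # constants are skipped
}
node_id = "datasource_1"

result: dict[str, Any] = {}
if datasource_parameters:
    for name, inp in datasource_parameters.items():
        if inp.type == "variable":
            result[name] = inp.value
result = {f"{node_id}.{key}": value for key, value in result.items()}
print(result)  # {'datasource_1.url': ['start', 'url']}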

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Union, Optional
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo

View File

@ -3,6 +3,7 @@ import logging
from collections.abc import Mapping
from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables.segments import ObjectSegment
@ -10,16 +11,15 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
from ..base import BaseNode
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
)
from ..base import BaseNode
logger = logging.getLogger(__name__)
@ -41,6 +41,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = variable_pool.get(node_data.index_chunk_variable_selector)
is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER
if not isinstance(variable, ObjectSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -55,6 +56,13 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
)
# retrieve knowledge
try:
if is_preview:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs={"result": "success"},
)
results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool)
outputs = {"result": results}
return NodeRunResult(
@ -90,15 +98,15 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
if not batch:
raise KnowledgeIndexNodeError("Batch is required.")
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
document = Document.query.filter_by(id=document_id).first()
document = db.session.query(Document).filter_by(id=document_id).first()
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor()
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
index_processor.index(dataset, document, chunks)
# update document status

View File

@ -270,7 +270,7 @@ class DatasetService:
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag_pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info,
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
created_by=current_user.id,
pipeline_id=pipeline.id,
)
@ -299,7 +299,7 @@ class DatasetService:
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info,
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
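
The .model_dump() calls above convert the Pydantic icon_info model into a plain dict before it reaches the Dataset constructor. A small sketch of the effect; the field names here are illustrative, not taken from this diff.

from typing import Optional

from pydantic import BaseModel


class IconInfoSketch(BaseModel):  # hypothetical stand-in for IconInfo
    icon: str = "default"
    icon_type: str = "emoji"
    icon_background: Optional[str] = None
    icon_url: Optional[str] = None


info = IconInfoSketch(icon="book")
print(info.model_dump())
# {'icon': 'book', 'icon_type': 'emoji', 'icon_background': None, 'icon_url': None}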

View File

@ -21,8 +21,8 @@ class RagPipelineDatasetCreateEntity(BaseModel):
description: str
icon_info: IconInfo
permission: str
partial_member_list: list[str]
yaml_content: str
partial_member_list: Optional[list[str]] = None
yaml_content: Optional[str] = None
class RerankingModelConfig(BaseModel):