mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 10:16:40 +08:00
r2
This commit is contained in:
parent
a49942b949
commit
64d997fdb0
@ -8,6 +8,7 @@ from flask_restful.inputs import int_range # type: ignore
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
|
from models.model import EndUser
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
@ -44,7 +45,6 @@ from services.errors.llm import InvokeRateLimitError
|
|||||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
|
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
|
||||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -243,6 +243,7 @@ class DraftRagPipelineRunApi(Resource):
|
|||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("datasource_info", type=list, required=True, location="json")
|
parser.add_argument("datasource_info", type=list, required=True, location="json")
|
||||||
|
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -313,13 +314,20 @@ class RagPipelineDatasourceNodeRunApi(Resource):
|
|||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
|
if inputs == None:
|
||||||
|
raise ValueError("missing inputs")
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
result = rag_pipeline_service.run_datasource_workflow_node(
|
result = rag_pipeline_service.run_datasource_workflow_node(
|
||||||
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
|
pipeline=pipeline,
|
||||||
|
node_id=node_id,
|
||||||
|
user_inputs=inputs,
|
||||||
|
account=current_user,
|
||||||
|
datasource_type=args.get("datasource_type"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@ -648,40 +656,6 @@ class RagPipelineByIdApi(Resource):
|
|||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_rag_pipeline
|
|
||||||
def delete(self, pipeline: Pipeline, workflow_id: str):
|
|
||||||
"""
|
|
||||||
Delete workflow
|
|
||||||
"""
|
|
||||||
# Check permission
|
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
|
||||||
|
|
||||||
# Create a session and manage the transaction
|
|
||||||
with Session(db.engine) as session:
|
|
||||||
try:
|
|
||||||
rag_pipeline_service.delete_workflow(
|
|
||||||
session=session, workflow_id=workflow_id, tenant_id=pipeline.tenant_id
|
|
||||||
)
|
|
||||||
# Commit the transaction in the controller
|
|
||||||
session.commit()
|
|
||||||
except WorkflowInUseError as e:
|
|
||||||
abort(400, description=str(e))
|
|
||||||
except DraftWorkflowDeletionError as e:
|
|
||||||
abort(400, description=str(e))
|
|
||||||
except ValueError as e:
|
|
||||||
raise NotFound(str(e))
|
|
||||||
|
|
||||||
return None, 204
|
|
||||||
|
|
||||||
|
|
||||||
class PublishedRagPipelineSecondStepApi(Resource):
|
class PublishedRagPipelineSecondStepApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@ -695,8 +669,12 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
node_id = request.args.get("node_id", required=True, type=str)
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("node_id", type=str, required=True, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
node_id = args.get("node_id")
|
||||||
|
if not node_id:
|
||||||
|
raise ValueError("Node ID is required")
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id)
|
variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id)
|
||||||
return {
|
return {
|
||||||
@ -716,7 +694,12 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
node_id = request.args.get("node_id", required=True, type=str)
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("node_id", type=str, required=True, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
node_id = args.get("node_id")
|
||||||
|
if not node_id:
|
||||||
|
raise ValueError("Node ID is required")
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id)
|
variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id)
|
||||||
@ -777,9 +760,11 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
|||||||
run_id = str(run_id)
|
run_id = str(run_id)
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
|
user = cast("Account | EndUser", current_user)
|
||||||
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
|
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"data": node_executions}
|
return {"data": node_executions}
|
||||||
@ -875,9 +860,9 @@ api.add_resource(
|
|||||||
)
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
PublishedRagPipelineSecondStepApi,
|
PublishedRagPipelineSecondStepApi,
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/paramters",
|
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
|
||||||
)
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
DraftRagPipelineSecondStepApi,
|
DraftRagPipelineSecondStepApi,
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/paramters",
|
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -99,6 +99,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
inputs: Mapping[str, Any] = args["inputs"]
|
inputs: Mapping[str, Any] = args["inputs"]
|
||||||
|
start_node_id: str = args["start_node_id"]
|
||||||
datasource_type: str = args["datasource_type"]
|
datasource_type: str = args["datasource_type"]
|
||||||
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
|
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
|
||||||
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||||
@ -118,7 +119,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
position=position,
|
position=position,
|
||||||
account=user,
|
account=user,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
document_form=pipeline.dataset.doc_form,
|
document_form=pipeline.dataset.chunk_structure,
|
||||||
)
|
)
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -231,7 +232,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
def single_iteration_generate(
|
def single_iteration_generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
pipeline: Pipeline,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
user: Account | EndUser,
|
user: Account | EndUser,
|
||||||
@ -255,7 +256,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||||||
raise ValueError("inputs is required")
|
raise ValueError("inputs is required")
|
||||||
|
|
||||||
# convert to app config
|
# convert to app config
|
||||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = WorkflowAppGenerateEntity(
|
application_generate_entity = WorkflowAppGenerateEntity(
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
|
||||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.plugin.impl.tool import PluginToolManager
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
@ -11,9 +10,11 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
|||||||
|
|
||||||
class DatasourcePluginProviderController(ABC):
|
class DatasourcePluginProviderController(ABC):
|
||||||
entity: DatasourceProviderEntityWithPlugin
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
tenant_id: str
|
||||||
|
|
||||||
def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None:
|
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||||
self.entity = entity
|
self.entity = entity
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def need_credentials(self) -> bool:
|
def need_credentials(self) -> bool:
|
||||||
@ -51,21 +52,6 @@ class DatasourcePluginProviderController(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore
|
|
||||||
"""
|
|
||||||
get all datasources
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
DatasourcePlugin(
|
|
||||||
entity=datasource_entity,
|
|
||||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
icon=self.entity.identity.icon,
|
|
||||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
|
||||||
)
|
|
||||||
for datasource_entity in self.entity.datasources
|
|
||||||
]
|
|
||||||
|
|
||||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
validate the format of the credentials of the provider and set the default value if needed
|
validate the format of the credentials of the provider and set the default value if needed
|
||||||
|
|||||||
@ -6,7 +6,11 @@ import contexts
|
|||||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
from core.datasource.entities.common_entities import I18nObject
|
from core.datasource.entities.common_entities import I18nObject
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||||
|
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||||
|
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||||
|
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -19,7 +23,9 @@ class DatasourceManager:
|
|||||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController:
|
def get_datasource_plugin_provider(
|
||||||
|
cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||||
|
) -> DatasourcePluginProviderController:
|
||||||
"""
|
"""
|
||||||
get the datasource plugin provider
|
get the datasource plugin provider
|
||||||
"""
|
"""
|
||||||
@ -40,12 +46,30 @@ class DatasourceManager:
|
|||||||
if not provider_entity:
|
if not provider_entity:
|
||||||
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
|
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
|
||||||
|
|
||||||
controller = DatasourcePluginProviderController(
|
match (datasource_type):
|
||||||
entity=provider_entity.declaration,
|
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||||
plugin_id=provider_entity.plugin_id,
|
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
entity=provider_entity.declaration,
|
||||||
tenant_id=tenant_id,
|
plugin_id=provider_entity.plugin_id,
|
||||||
)
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||||
|
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.LOCAL_FILE:
|
||||||
|
controller = LocalFileDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||||
|
|
||||||
datasource_plugin_providers[provider] = controller
|
datasource_plugin_providers[provider] = controller
|
||||||
|
|
||||||
@ -57,6 +81,7 @@ class DatasourceManager:
|
|||||||
provider_id: str,
|
provider_id: str,
|
||||||
datasource_name: str,
|
datasource_name: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
datasource_type: DatasourceProviderType,
|
||||||
) -> DatasourcePlugin:
|
) -> DatasourcePlugin:
|
||||||
"""
|
"""
|
||||||
get the datasource runtime
|
get the datasource runtime
|
||||||
@ -68,21 +93,10 @@ class DatasourceManager:
|
|||||||
|
|
||||||
:return: the datasource plugin
|
:return: the datasource plugin
|
||||||
"""
|
"""
|
||||||
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
|
return cls.get_datasource_plugin_provider(
|
||||||
|
provider_id,
|
||||||
|
tenant_id,
|
||||||
|
datasource_type,
|
||||||
|
).get_datasource(datasource_name)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
|
|
||||||
"""
|
|
||||||
list all the datasource providers
|
|
||||||
"""
|
|
||||||
manager = PluginDatasourceManager()
|
|
||||||
provider_entities = manager.fetch_datasource_providers(tenant_id)
|
|
||||||
return [
|
|
||||||
DatasourcePluginProviderController(
|
|
||||||
entity=provider.declaration,
|
|
||||||
plugin_id=provider.plugin_id,
|
|
||||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
)
|
|
||||||
for provider in provider_entities
|
|
||||||
]
|
|
||||||
|
|||||||
@ -251,7 +251,7 @@ class GetOnlineDocumentPageContentRequest(BaseModel):
|
|||||||
Get online document page content request
|
Get online document page content request
|
||||||
"""
|
"""
|
||||||
|
|
||||||
online_document_info_list: list[OnlineDocumentInfo]
|
online_document_info: OnlineDocumentInfo
|
||||||
|
|
||||||
|
|
||||||
class OnlineDocumentPageContent(BaseModel):
|
class OnlineDocumentPageContent(BaseModel):
|
||||||
@ -259,6 +259,7 @@ class OnlineDocumentPageContent(BaseModel):
|
|||||||
Online document page content
|
Online document page content
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
workspace_id: str = Field(..., description="The workspace id")
|
||||||
page_id: str = Field(..., description="The page id")
|
page_id: str = Field(..., description="The page id")
|
||||||
content: str = Field(..., description="The content of the page")
|
content: str = Field(..., description="The content of the page")
|
||||||
|
|
||||||
@ -268,7 +269,7 @@ class GetOnlineDocumentPageContentResponse(BaseModel):
|
|||||||
Get online document page content response
|
Get online document page content response
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result: list[OnlineDocumentPageContent]
|
result: OnlineDocumentPageContent
|
||||||
|
|
||||||
|
|
||||||
class GetWebsiteCrawlRequest(BaseModel):
|
class GetWebsiteCrawlRequest(BaseModel):
|
||||||
@ -286,7 +287,7 @@ class WebSiteInfo(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
source_url: str = Field(..., description="The url of the website")
|
source_url: str = Field(..., description="The url of the website")
|
||||||
markdown: str = Field(..., description="The markdown of the website")
|
content: str = Field(..., description="The content of the website")
|
||||||
title: str = Field(..., description="The title of the website")
|
title: str = Field(..., description="The title of the website")
|
||||||
description: str = Field(..., description="The description of the website")
|
description: str = Field(..., description="The description of the website")
|
||||||
|
|
||||||
@ -296,4 +297,4 @@ class GetWebsiteCrawlResponse(BaseModel):
|
|||||||
Get website crawl response
|
Get website crawl response
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result: list[WebSiteInfo]
|
result: WebSiteInfo
|
||||||
|
|||||||
@ -26,12 +26,3 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
|
|||||||
|
|
||||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||||
return DatasourceProviderType.LOCAL_FILE
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
|
||||||
return DatasourcePlugin(
|
|
||||||
entity=self.entity,
|
|
||||||
runtime=runtime,
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
icon=self.icon,
|
|
||||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -8,15 +8,13 @@ from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlug
|
|||||||
|
|
||||||
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
entity: DatasourceProviderEntityWithPlugin
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
tenant_id: str
|
|
||||||
plugin_id: str
|
plugin_id: str
|
||||||
plugin_unique_identifier: str
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(entity)
|
super().__init__(entity, tenant_id)
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.plugin_id = plugin_id
|
self.plugin_id = plugin_id
|
||||||
self.plugin_unique_identifier = plugin_unique_identifier
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
|||||||
@ -69,12 +69,3 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
|||||||
|
|
||||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||||
|
|
||||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
|
||||||
return DatasourcePlugin(
|
|
||||||
entity=self.entity,
|
|
||||||
runtime=runtime,
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
icon=self.icon,
|
|
||||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,20 +1,18 @@
|
|||||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
|
||||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
|
|
||||||
|
|
||||||
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
entity: DatasourceProviderEntityWithPlugin
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
tenant_id: str
|
|
||||||
plugin_id: str
|
plugin_id: str
|
||||||
plugin_unique_identifier: str
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(entity)
|
super().__init__(entity, tenant_id)
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.plugin_id = plugin_id
|
self.plugin_id = plugin_id
|
||||||
self.plugin_unique_identifier = plugin_unique_identifier
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
@ -25,7 +23,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
|
|||||||
"""
|
"""
|
||||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||||
|
|
||||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
|
||||||
"""
|
"""
|
||||||
return datasource with given name
|
return datasource with given name
|
||||||
"""
|
"""
|
||||||
@ -41,7 +39,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
|
|||||||
if not datasource_entity:
|
if not datasource_entity:
|
||||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
return DatasourcePlugin(
|
return OnlineDocumentDatasourcePlugin(
|
||||||
entity=datasource_entity,
|
entity=datasource_entity,
|
||||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import (
|
|||||||
GetWebsiteCrawlResponse,
|
GetWebsiteCrawlResponse,
|
||||||
)
|
)
|
||||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
|
||||||
|
|
||||||
|
|
||||||
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||||
@ -38,9 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
|||||||
) -> GetWebsiteCrawlResponse:
|
) -> GetWebsiteCrawlResponse:
|
||||||
manager = PluginDatasourceManager()
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
return manager.get_website_crawl(
|
||||||
|
|
||||||
return manager.invoke_first_step(
|
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
datasource_provider=self.entity.identity.provider,
|
datasource_provider=self.entity.identity.provider,
|
||||||
@ -52,12 +49,3 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
|||||||
|
|
||||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||||
return DatasourceProviderType.WEBSITE_CRAWL
|
return DatasourceProviderType.WEBSITE_CRAWL
|
||||||
|
|
||||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
|
||||||
return DatasourcePlugin(
|
|
||||||
entity=self.entity,
|
|
||||||
runtime=runtime,
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
icon=self.icon,
|
|
||||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,20 +1,18 @@
|
|||||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
|
||||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||||
|
|
||||||
|
|
||||||
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
entity: DatasourceProviderEntityWithPlugin
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
tenant_id: str
|
|
||||||
plugin_id: str
|
plugin_id: str
|
||||||
plugin_unique_identifier: str
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(entity)
|
super().__init__(entity, tenant_id)
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.plugin_id = plugin_id
|
self.plugin_id = plugin_id
|
||||||
self.plugin_unique_identifier = plugin_unique_identifier
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
@ -25,7 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
|||||||
"""
|
"""
|
||||||
return DatasourceProviderType.WEBSITE_CRAWL
|
return DatasourceProviderType.WEBSITE_CRAWL
|
||||||
|
|
||||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
|
||||||
"""
|
"""
|
||||||
return datasource with given name
|
return datasource with given name
|
||||||
"""
|
"""
|
||||||
@ -41,7 +39,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
|||||||
if not datasource_entity:
|
if not datasource_entity:
|
||||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
return DatasourcePlugin(
|
return WebsiteCrawlDatasourcePlugin(
|
||||||
entity=datasource_entity,
|
entity=datasource_entity,
|
||||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from typing import Any, Optional, Union
|
|||||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.plugin.entities.oauth import OAuthSchema
|
|
||||||
from core.plugin.entities.parameters import (
|
from core.plugin.entities.parameters import (
|
||||||
PluginParameter,
|
PluginParameter,
|
||||||
PluginParameterOption,
|
PluginParameterOption,
|
||||||
@ -350,7 +349,6 @@ class ToolProviderEntity(BaseModel):
|
|||||||
identity: ToolProviderIdentity
|
identity: ToolProviderIdentity
|
||||||
plugin_id: Optional[str] = None
|
plugin_id: Optional[str] = None
|
||||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||||
oauth_schema: Optional[OAuthSchema] = Field(default=None, description="The oauth schema of the tool provider")
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||||
|
|||||||
@ -4,6 +4,9 @@ from typing import Any, cast
|
|||||||
from core.datasource.entities.datasource_entities import (
|
from core.datasource.entities.datasource_entities import (
|
||||||
DatasourceParameter,
|
DatasourceParameter,
|
||||||
DatasourceProviderType,
|
DatasourceProviderType,
|
||||||
|
GetOnlineDocumentPageContentRequest,
|
||||||
|
GetOnlineDocumentPageContentResponse,
|
||||||
|
GetWebsiteCrawlRequest,
|
||||||
GetWebsiteCrawlResponse,
|
GetWebsiteCrawlResponse,
|
||||||
)
|
)
|
||||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
@ -54,6 +57,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
provider_id=node_data.provider_id,
|
provider_id=node_data.provider_id,
|
||||||
datasource_name=node_data.datasource_name,
|
datasource_name=node_data.datasource_name,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
|
datasource_type=DatasourceProviderType(node_data.provider_type),
|
||||||
)
|
)
|
||||||
except DatasourceNodeError as e:
|
except DatasourceNodeError as e:
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
@ -82,38 +86,43 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: handle result
|
|
||||||
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
|
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||||
result = datasource_runtime._get_online_document_page_content(
|
online_document_result: GetOnlineDocumentPageContentResponse = (
|
||||||
user_id=self.user_id,
|
datasource_runtime._get_online_document_page_content(
|
||||||
datasource_parameters=parameters,
|
user_id=self.user_id,
|
||||||
provider_type=node_data.provider_type,
|
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
|
||||||
|
provider_type=datasource_runtime.datasource_provider_type(),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return NodeRunResult(
|
yield RunCompletedEvent(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
run_result=NodeRunResult(
|
||||||
inputs=parameters_for_log,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
inputs=parameters_for_log,
|
||||||
outputs={
|
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
"result": result.result.model_dump(),
|
outputs={
|
||||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
"online_document": online_document_result.result.model_dump(),
|
||||||
},
|
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
|
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
|
||||||
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
||||||
result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
datasource_parameters=parameters,
|
datasource_parameters=GetWebsiteCrawlRequest(**parameters),
|
||||||
provider_type=node_data.provider_type,
|
provider_type=datasource_runtime.datasource_provider_type(),
|
||||||
)
|
)
|
||||||
return NodeRunResult(
|
yield RunCompletedEvent(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
run_result=NodeRunResult(
|
||||||
inputs=parameters_for_log,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
inputs=parameters_for_log,
|
||||||
outputs={
|
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
"result": result.result.model_dump(),
|
outputs={
|
||||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
"website": website_crawl_result.result.model_dump(),
|
||||||
},
|
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise DatasourceNodeError(
|
raise DatasourceNodeError(
|
||||||
|
|||||||
@ -360,7 +360,7 @@ class Workflow(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rag_pipeline_variables(self) -> Sequence[Variable]:
|
def rag_pipeline_variables(self) -> list[dict]:
|
||||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||||
if self._rag_pipeline_variables is None:
|
if self._rag_pipeline_variables is None:
|
||||||
self._rag_pipeline_variables = "{}"
|
self._rag_pipeline_variables = "{}"
|
||||||
|
|||||||
@ -2,12 +2,11 @@ from collections.abc import Mapping
|
|||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
|
||||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from models.model import Account, App, AppMode, EndUser
|
from models.model import Account, App, EndUser
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
|
|
||||||
@ -57,23 +56,15 @@ class PipelineGenerateService:
|
|||||||
return max_active_requests
|
return max_active_requests
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
def generate_single_iteration(
|
||||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True
|
||||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
):
|
||||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
|
||||||
AdvancedChatAppGenerator().single_iteration_generate(
|
return PipelineGenerator.convert_to_event_stream(
|
||||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
PipelineGenerator().single_iteration_generate(
|
||||||
)
|
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||||
)
|
)
|
||||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
)
|
||||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
|
||||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
|
||||||
WorkflowAppGenerator().single_iteration_generate(
|
|
||||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
|
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Generator, Sequence
|
from collections.abc import Callable, Generator, Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
@ -12,6 +12,9 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderType, GetOnlineDocumentPagesRequest, GetOnlineDocumentPagesResponse, GetWebsiteCrawlRequest, GetWebsiteCrawlResponse
|
||||||
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
|
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from core.variables.variables import Variable
|
from core.variables.variables import Variable
|
||||||
@ -30,6 +33,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
|||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
|
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
|
||||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||||
|
from models.model import EndUser
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNodeExecution,
|
WorkflowNodeExecution,
|
||||||
@ -394,8 +398,8 @@ class RagPipelineService:
|
|||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def run_datasource_workflow_node(
|
def run_datasource_workflow_node(
|
||||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
|
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str
|
||||||
) -> WorkflowNodeExecution:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Run published workflow datasource
|
Run published workflow datasource
|
||||||
"""
|
"""
|
||||||
@ -416,17 +420,36 @@ class RagPipelineService:
|
|||||||
provider_id=datasource_node_data.get("provider_id"),
|
provider_id=datasource_node_data.get("provider_id"),
|
||||||
datasource_name=datasource_node_data.get("datasource_name"),
|
datasource_name=datasource_node_data.get("datasource_name"),
|
||||||
tenant_id=pipeline.tenant_id,
|
tenant_id=pipeline.tenant_id,
|
||||||
|
datasource_type=DatasourceProviderType(datasource_type),
|
||||||
)
|
)
|
||||||
result = datasource_runtime._invoke_first_step(
|
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||||
inputs=user_inputs,
|
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||||
provider_type=datasource_node_data.get("provider_type"),
|
online_document_result: GetOnlineDocumentPagesResponse = (
|
||||||
user_id=account.id,
|
datasource_runtime._get_online_document_pages(
|
||||||
)
|
user_id=account.id,
|
||||||
|
datasource_parameters=GetOnlineDocumentPagesRequest(tenant_id=pipeline.tenant_id),
|
||||||
|
provider_type=datasource_runtime.datasource_provider_type(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"result": [page.model_dump() for page in online_document_result.result],
|
||||||
|
"provider_type": datasource_node_data.get("provider_type"),
|
||||||
|
}
|
||||||
|
|
||||||
|
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
|
||||||
|
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
||||||
|
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
||||||
|
user_id=account.id,
|
||||||
|
datasource_parameters=GetWebsiteCrawlRequest(**user_inputs),
|
||||||
|
provider_type=datasource_runtime.datasource_provider_type(),
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"result": website_crawl_result.result.model_dump(),
|
||||||
|
"provider_type": datasource_node_data.get("provider_type"),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
|
||||||
|
|
||||||
return {
|
|
||||||
"result": result,
|
|
||||||
"provider_type": datasource_node_data.get("provider_type"),
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_free_workflow_node(
|
def run_free_workflow_node(
|
||||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||||
@ -587,7 +610,7 @@ class RagPipelineService:
|
|||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
|
def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
@ -599,7 +622,7 @@ class RagPipelineService:
|
|||||||
# get second step node
|
# get second step node
|
||||||
rag_pipeline_variables = workflow.rag_pipeline_variables
|
rag_pipeline_variables = workflow.rag_pipeline_variables
|
||||||
if not rag_pipeline_variables:
|
if not rag_pipeline_variables:
|
||||||
return {}
|
return []
|
||||||
|
|
||||||
# get datasource provider
|
# get datasource provider
|
||||||
datasource_provider_variables = [
|
datasource_provider_variables = [
|
||||||
@ -609,7 +632,7 @@ class RagPipelineService:
|
|||||||
]
|
]
|
||||||
return datasource_provider_variables
|
return datasource_provider_variables
|
||||||
|
|
||||||
def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
|
def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
@ -621,7 +644,7 @@ class RagPipelineService:
|
|||||||
# get second step node
|
# get second step node
|
||||||
rag_pipeline_variables = workflow.rag_pipeline_variables
|
rag_pipeline_variables = workflow.rag_pipeline_variables
|
||||||
if not rag_pipeline_variables:
|
if not rag_pipeline_variables:
|
||||||
return {}
|
return []
|
||||||
|
|
||||||
# get datasource provider
|
# get datasource provider
|
||||||
datasource_provider_variables = [
|
datasource_provider_variables = [
|
||||||
@ -702,6 +725,7 @@ class RagPipelineService:
|
|||||||
self,
|
self,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
|
user: Account | EndUser,
|
||||||
) -> list[WorkflowNodeExecution]:
|
) -> list[WorkflowNodeExecution]:
|
||||||
"""
|
"""
|
||||||
Get workflow run node execution list
|
Get workflow run node execution list
|
||||||
@ -716,11 +740,16 @@ class RagPipelineService:
|
|||||||
|
|
||||||
# Use the repository to get the node execution
|
# Use the repository to get the node execution
|
||||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id
|
session_factory=db.engine,
|
||||||
|
app_id=pipeline.id,
|
||||||
|
user=user,
|
||||||
|
triggered_from=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the repository to get the node executions with ordering
|
# Use the repository to get the node executions with ordering
|
||||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||||
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
||||||
|
# Convert domain models to database models
|
||||||
|
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
|
||||||
|
|
||||||
return list(node_executions)
|
return workflow_node_executions
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user