mirror of https://github.com/langgenius/dify.git
This commit is contained in:
parent
1aa13bd20d
commit
21a3509bef
|
|
@ -300,6 +300,86 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
job_id = args.get("job_id")
|
||||
if job_id == None:
|
||||
raise ValueError("missing job_id")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type == None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
result = rag_pipeline_service.run_datasource_workflow_node_status(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
job_id=job_id,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
job_id = args.get("job_id")
|
||||
if job_id == None:
|
||||
raise ValueError("missing job_id")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type == None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
result = rag_pipeline_service.run_datasource_workflow_node_status(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
job_id=job_id,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=False
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -894,6 +974,14 @@ api.add_resource(
|
|||
RagPipelinePublishedDatasourceNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelinePublishedDatasourceNodeRunStatusApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run-status",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineDraftDatasourceNodeRunStatusApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run-status",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDrafDatasourceNodeRunApi,
|
||||
|
|
|
|||
|
|
@ -304,4 +304,6 @@ class GetWebsiteCrawlResponse(BaseModel):
|
|||
Get website crawl response
|
||||
"""
|
||||
|
||||
result: list[WebSiteInfo]
|
||||
result: Optional[list[WebSiteInfo]] = []
|
||||
job_id: str = Field(..., description="The job id")
|
||||
status: str = Field(..., description="The status of the job")
|
||||
|
|
|
|||
|
|
@ -415,6 +415,67 @@ class RagPipelineService:
|
|||
|
||||
return workflow_node_execution
|
||||
|
||||
def run_datasource_workflow_node_status(
|
||||
self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
|
||||
) -> dict:
|
||||
"""
|
||||
Run published workflow datasource
|
||||
"""
|
||||
if is_published:
|
||||
# fetch published workflow by app_model
|
||||
workflow = self.get_published_workflow(pipeline=pipeline)
|
||||
else:
|
||||
workflow = self.get_draft_workflow(pipeline=pipeline)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# run draft workflow node
|
||||
datasource_node_data = None
|
||||
start_at = time.perf_counter()
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
for datasource_node in datasource_nodes:
|
||||
if datasource_node.get("id") == node_id:
|
||||
datasource_node_data = datasource_node.get("data", {})
|
||||
break
|
||||
if not datasource_node_data:
|
||||
raise ValueError("Datasource node data not found")
|
||||
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
|
||||
datasource_name=datasource_node_data.get("datasource_name"),
|
||||
tenant_id=pipeline.tenant_id,
|
||||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_real_datasource_credentials(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
provider=datasource_node_data.get('provider_name'),
|
||||
plugin_id=datasource_node_data.get('plugin_id'),
|
||||
)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
||||
match datasource_type:
|
||||
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
||||
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
||||
user_id=account.id,
|
||||
datasource_parameters={"job_id": job_id},
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
return {
|
||||
"result": [result.model_dump() for result in website_crawl_result.result],
|
||||
"job_id": website_crawl_result.job_id,
|
||||
"status": website_crawl_result.status,
|
||||
"provider_type": datasource_node_data.get("provider_type"),
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
|
||||
|
||||
|
||||
def run_datasource_workflow_node(
|
||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool
|
||||
) -> dict:
|
||||
|
|
|
|||
|
|
@ -94,7 +94,6 @@ services:
|
|||
PLUGIN_REMOTE_INSTALLING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0}
|
||||
PLUGIN_REMOTE_INSTALLING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003}
|
||||
PLUGIN_WORKING_PATH: ${PLUGIN_WORKING_PATH:-/app/storage/cwd}
|
||||
FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true}
|
||||
PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
|
||||
PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600}
|
||||
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
|
||||
|
|
@ -127,6 +126,7 @@ services:
|
|||
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
|
||||
THIRD_PARTY_SIGNATURE_VERIFICATION_ENABLED: true
|
||||
THIRD_PARTY_SIGNATURE_VERIFICATION_PUBLIC_KEYS: /app/keys/publickey.pem
|
||||
FORCE_VERIFYING_SIGNATURE: false
|
||||
ports:
|
||||
- "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}"
|
||||
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
|
||||
|
|
|
|||
Loading…
Reference in New Issue