diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 93976bd6f5..334c11bddb 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -8,6 +8,7 @@ from controllers.console import api from controllers.console.wraps import ( account_initialization_required, enterprise_license_required, + knowledge_pipeline_publish_enabled, setup_required, ) from extensions.ext_database import db @@ -116,6 +117,7 @@ class PublishCustomizedPipelineTemplateApi(Resource): @login_required @account_initialization_required @enterprise_license_required + @knowledge_pipeline_publish_enabled def post(self, pipeline_id: str): parser = reqparse.RequestParser() parser.add_argument( diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index d862dac373..3e1237615a 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -261,3 +261,14 @@ def is_allow_transfer_owner(view): abort(403) return decorated + + +def knowledge_pipeline_publish_enabled(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_features(current_user.current_tenant_id) + if features.knowledge_pipeline.publish_enabled: + return view(*args, **kwargs) + abort(403) + + return decorated diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 89ada53831..9c97b8109f 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -13,7 +13,6 @@ from core.app.entities.app_invoke_entities import ( from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent from core.workflow.graph_engine.entities.graph import Graph from core.workflow.system_variable import SystemVariable diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index d92134db58..f19128b445 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -18,6 +18,7 @@ from core.workflow.constants import ( ) from core.workflow.system_variable import SystemVariable from factories import variable_factory + VariableValue = Union[str, int, float, dict, list, File] VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") @@ -75,7 +76,6 @@ class VariablePool(BaseModel): for key, value in rag_pipeline_variables_map.items(): self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) - def add(self, selector: Sequence[str], value: Any, /) -> None: """ Add a variable to the variable pool. diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 823514de44..fbab9c631d 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -138,7 +138,7 @@ class DatasourceNode(BaseNode): datasource_runtime.get_online_document_page_content( user_id=self.user_id, datasource_parameters=GetOnlineDocumentPageContentRequest( - workspace_id=datasource_info.get("workspace_id"), + workspace_id=datasource_info.get("workspace_id"), page_id=datasource_info.get("page").get("page_id"), type=datasource_info.get("page").get("type"), ), @@ -205,7 +205,7 @@ class DatasourceNode(BaseNode): storage_key=upload_file.key, ) variable_pool.add([self.node_id, "file"], file_info) - #variable_pool.add([self.node_id, "file"], file_info.to_dict()) + # variable_pool.add([self.node_id, "file"], file_info.to_dict()) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 1441e6ce16..3bb0fff0a8 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -88,6 +88,10 @@ class WebAppAuthModel(BaseModel): allow_email_password_login: bool = False +class KnowledgePipeline(BaseModel): + publish_enabled: bool = False + + class PluginInstallationScope(StrEnum): NONE = "none" OFFICIAL_ONLY = "official_only" @@ -126,6 +130,7 @@ class FeatureModel(BaseModel): is_allow_transfer_workspace: bool = True # pydantic configs model_config = ConfigDict(protected_namespaces=()) + knowledge_pipeline: KnowledgePipeline = KnowledgePipeline() class KnowledgeRateLimitModel(BaseModel): @@ -265,6 +270,9 @@ class FeatureService: if "knowledge_rate_limit" in billing_info: features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] + if "knowledge_pipeline_publish_enabled" in billing_info: + features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"] + @classmethod def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): enterprise_info = EnterpriseService.get_info()