From e710a8402c02dd83b72c7d990334f79fadde0745 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Thu, 15 May 2025 16:07:17 +0800
Subject: [PATCH] r2

---
 .../datasource/__base/datasource_plugin.py    |   6 +-
 api/core/datasource/datasource_manager.py     |  12 +-
 .../entities/datasource_entities.py           |   3 -
 api/core/plugin/entities/plugin.py            |   3 +
 api/core/plugin/impl/tool.py                  |  32 ++++-
 ..._15_1558-b35c3db83d09_add_pipeline_info.py | 113 ++++++++++++++++++
 api/models/dataset.py                         |   6 +-
 api/models/workflow.py                        |  12 +-
 8 files changed, 165 insertions(+), 22 deletions(-)
 create mode 100644 api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py

diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py
index 86bd66a3f9..8fb89e1172 100644
--- a/api/core/datasource/__base/datasource_plugin.py
+++ b/api/core/datasource/__base/datasource_plugin.py
@@ -2,13 +2,13 @@ from collections.abc import Generator
 from typing import Any, Optional
 
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
-from core.datasource.datasource_manager import DatasourceManager
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
     DatasourceInvokeMessage,
     DatasourceParameter,
     DatasourceProviderType,
 )
+from core.plugin.impl.datasource import PluginDatasourceManager
 from core.plugin.utils.converter import convert_parameters_to_plugin_format
 
 
@@ -44,7 +44,7 @@ class DatasourcePlugin:
         datasource_parameters: dict[str, Any],
         rag_pipeline_id: Optional[str] = None,
     ) -> Generator[DatasourceInvokeMessage, None, None]:
-        manager = DatasourceManager()
+        manager = PluginDatasourceManager()
 
         datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
 
@@ -64,7 +64,7 @@ class DatasourcePlugin:
         datasource_parameters: dict[str, Any],
         rag_pipeline_id: Optional[str] = None,
     ) -> Generator[DatasourceInvokeMessage, None, None]:
-        manager = DatasourceManager()
+        manager = PluginDatasourceManager()
 
         datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
 
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index 195d430015..fa141a679a 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -7,8 +7,8 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
 from core.datasource.entities.common_entities import I18nObject
 from core.datasource.entities.datasource_entities import DatasourceProviderType
-from core.datasource.errors import ToolProviderNotFoundError
-from core.plugin.manager.tool import PluginToolManager
+from core.datasource.errors import DatasourceProviderNotFoundError
+from core.plugin.impl.tool import PluginToolManager
 
 logger = logging.getLogger(__name__)
 
@@ -37,9 +37,9 @@ class DatasourceManager:
             return datasource_plugin_providers[provider]
 
         manager = PluginToolManager()
-        provider_entity = manager.fetch_tool_provider(tenant_id, provider)
+        provider_entity = manager.fetch_datasource_provider(tenant_id, provider)
         if not provider_entity:
-            raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
+            raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
 
         controller = DatasourcePluginProviderController(
             entity=provider_entity.declaration,
@@ -73,7 +73,7 @@ class DatasourceManager:
         if provider_type == DatasourceProviderType.RAG_PIPELINE:
             return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
         else:
-            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
+            raise DatasourceProviderNotFoundError(f"provider type {provider_type.value} not found")
 
     @classmethod
     def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
@@ -81,7 +81,7 @@
         list all the datasource providers
         """
         manager = PluginToolManager()
-        provider_entities = manager.fetch_tool_providers(tenant_id)
+        provider_entities = manager.fetch_datasources(tenant_id)
         return [
             DatasourcePluginProviderController(
                 entity=provider.declaration,
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index 80e89ef1a9..6fc23e88cc 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -321,9 +321,6 @@ class DatasourceEntity(BaseModel):
     output_schema: Optional[dict] = None
     has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
 
-    # pydantic configs
-    model_config = ConfigDict(protected_namespaces=())
-
     @field_validator("parameters", mode="before")
     @classmethod
     def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py
index bdf7d5ce1f..85d4d130ba 100644
--- a/api/core/plugin/entities/plugin.py
+++ b/api/core/plugin/entities/plugin.py
@@ -192,6 +192,9 @@ class ToolProviderID(GenericProviderID):
         if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
             self.plugin_name = f"{self.provider_name}_tool"
 
+class DatasourceProviderID(GenericProviderID):
+    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
+        super().__init__(value, is_hardcoded)
 
 class PluginDependency(BaseModel):
     class Type(enum.StrEnum):
diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py
index f4360a70de..54f5418bb4 100644
--- a/api/core/plugin/impl/tool.py
+++ b/api/core/plugin/impl/tool.py
@@ -3,7 +3,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
-from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
+from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID, ToolProviderID
 from core.plugin.entities.plugin_daemon import (
     PluginBasicBooleanResponse,
     PluginDatasourceProviderEntity,
@@ -76,6 +76,36 @@
 
         return response
 
+    def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
+        """
+        Fetch datasource provider for the given tenant and plugin.
+ """ + datasource_provider_id = DatasourceProviderID(provider) + + def transformer(json_response: dict[str, Any]) -> dict: + data = json_response.get("data") + if data: + for tool in data.get("declaration", {}).get("tools", []): + tool["identity"]["provider"] = datasource_provider_id.provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/datasource", + PluginDatasourceProviderEntity, + params={"provider": datasource_provider_id.provider_name, "plugin_id": datasource_provider_id.plugin_id}, + transformer=transformer, + ) + + response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in response.declaration.tools: + tool.identity.provider = response.declaration.identity.name + + return response + def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: """ Fetch tool provider for the given tenant and plugin. diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py new file mode 100644 index 0000000000..89fcc6aa29 --- /dev/null +++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py @@ -0,0 +1,113 @@ +"""add_pipeline_info + +Revision ID: b35c3db83d09 +Revises: d28f2004b072 +Create Date: 2025-05-15 15:58:05.179877 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b35c3db83d09' +down_revision = 'd28f2004b072' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('pipeline_built_in_templates', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('pipeline_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('icon', sa.JSON(), nullable=False), + sa.Column('copyright', sa.String(length=255), nullable=False), + sa.Column('privacy_policy', sa.String(length=255), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('install_count', sa.Integer(), nullable=False), + sa.Column('language', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey') + ) + op.create_table('pipeline_customized_templates', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('pipeline_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('icon', sa.JSON(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('install_count', sa.Integer(), nullable=False), + sa.Column('language', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') + ) + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: + batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('pipelines', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False), + sa.Column('mode', sa.String(length=255), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=True), + sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('is_published', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pipeline_pkey') + ) + op.create_table('tool_builtin_datasource_providers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('user_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=256), nullable=False), + 
+    sa.Column('encrypted_credentials', sa.Text(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='tool_builtin_datasource_provider_pkey'),
+    sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_datasource_provider')
+    )
+
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
+        batch_op.add_column(sa.Column('icon_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+        batch_op.add_column(sa.Column('runtime_mode', sa.String(length=255), server_default=sa.text("'general'::character varying"), nullable=True))
+        batch_op.add_column(sa.Column('pipeline_id', models.types.StringUUID(), nullable=True))
+        batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=True))
+
+    with op.batch_alter_table('workflows', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflows', schema=None) as batch_op:
+        batch_op.drop_column('rag_pipeline_variables')
+
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.drop_column('chunk_structure')
+        batch_op.drop_column('pipeline_id')
+        batch_op.drop_column('runtime_mode')
+        batch_op.drop_column('icon_info')
+        batch_op.drop_column('keyword_number')
+
+    op.drop_table('tool_builtin_datasource_providers')
+    op.drop_table('pipelines')
+    with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
+        batch_op.drop_index('pipeline_customized_template_tenant_idx')
+
+    op.drop_table('pipeline_customized_templates')
+    op.drop_table('pipeline_built_in_templates')
+    # ### end Alembic commands ###
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 292f4aacfd..e60f110aef 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1149,7 +1149,7 @@ class DatasetMetadataBinding(Base):
     created_by = db.Column(StringUUID, nullable=False)
 
 
-class PipelineBuiltInTemplate(db.Model):  # type: ignore[name-defined]
+class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
     __tablename__ = "pipeline_built_in_templates"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
 
@@ -1167,7 +1167,7 @@ class PipelineBuiltInTemplate(db.Model):  # type: ignore[name-defined]
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-class PipelineCustomizedTemplate(db.Model):  # type: ignore[name-defined]
+class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
     __tablename__ = "pipeline_customized_templates"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
@@ -1187,7 +1187,7 @@
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-class Pipeline(db.Model):  # type: ignore[name-defined]
+class Pipeline(Base):  # type: ignore[name-defined]
     __tablename__ = "pipelines"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
 
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 2fda5431c3..b6b56ad520 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -128,8 +128,8 @@ class Workflow(Base):
     _conversation_variables: Mapped[str] = mapped_column(
         "conversation_variables", db.Text, nullable=False, server_default="{}"
     )
-    _pipeline_variables: Mapped[str] = mapped_column(
-        "conversation_variables", db.Text, nullable=False, server_default="{}"
+    _rag_pipeline_variables: Mapped[str] = mapped_column(
+        "rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
     )
 
     @classmethod
@@ -354,10 +354,10 @@ class Workflow(Base):
     @property
    def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
         # TODO: find some way to init `self._conversation_variables` when instance created.
-        if self._pipeline_variables is None:
-            self._pipeline_variables = "{}"
+        if self._rag_pipeline_variables is None:
+            self._rag_pipeline_variables = "{}"
 
-        variables_dict: dict[str, Any] = json.loads(self._pipeline_variables)
+        variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
         results = {}
         for k, v in variables_dict.items():
             results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
@@ -365,7 +365,7 @@
 
     @pipeline_variables.setter
     def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
-        self._pipeline_variables = json.dumps(
+        self._rag_pipeline_variables = json.dumps(
             {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
             ensure_ascii=False,
         )