diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index a0dc692c4e..6b86846e3e 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1004,6 +1004,11 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): + parser = reqparse.RequestParser() + parser.add_argument('type', type=str, location='args', required=False, default='all') + args = parser.parse_args() + plugin_type = args["type"] + rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins() + recommended_plugins = rag_pipeline_service.get_recommended_plugins(plugin_type) return recommended_plugins diff --git a/api/migrations/versions/2025_12_15_1614-6bb0832495f0_alter_table_pipeline_recommended_.py b/api/migrations/versions/2025_12_15_1614-6bb0832495f0_alter_table_pipeline_recommended_.py new file mode 100644 index 0000000000..ca3172665b --- /dev/null +++ b/api/migrations/versions/2025_12_15_1614-6bb0832495f0_alter_table_pipeline_recommended_.py @@ -0,0 +1,65 @@ +"""Alter table pipeline_recommended_plugins add column type + +Revision ID: 6bb0832495f0 +Revises: 7bb281b7a422 +Create Date: 2025-12-15 16:14:38.482072 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6bb0832495f0' +down_revision = '7bb281b7a422' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_triggers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=255), + nullable=False, + existing_server_default=sa.text("''::character varying")) + + with op.batch_alter_table('operation_logs', schema=None) as batch_op: + batch_op.alter_column('content', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op: + batch_op.add_column(sa.Column('type', sa.String(length=50), nullable=True)) + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('quota_used', + existing_type=sa.BIGINT(), + nullable=False) + + # ### end Alembic commands ### + +# “推荐插件”model添加type字段;查询推荐列表接口添加type参数 +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('quota_used', + existing_type=sa.BIGINT(), + nullable=True) + + with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op: + batch_op.drop_column('type') + + with op.batch_alter_table('operation_logs', schema=None) as batch_op: + batch_op.alter_column('content', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + with op.batch_alter_table('app_triggers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=255), + nullable=True, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index e072711b82..c4b4c6d985 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1458,6 +1458,7 @@ class PipelineRecommendedPlugin(TypeBase): ) plugin_id: Mapped[str] = mapped_column(LongText, nullable=False) provider_name: Mapped[str] = mapped_column(LongText, nullable=False) + type: Mapped[str] = mapped_column(sa.String(50), nullable=True) position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) created_at: Mapped[datetime] = mapped_column( diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 097d16e2a7..e438eae295 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1248,12 +1248,14 @@ class RagPipelineService: session.commit() return workflow_node_execution_db_model - def get_recommended_plugins(self) -> dict: + def get_recommended_plugins(self, type) -> dict: # Query active recommended plugins + query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) + if type and type != "all": + query = query.where(PipelineRecommendedPlugin.type == type) + pipeline_recommended_plugins = ( - db.session.query(PipelineRecommendedPlugin) - .where(PipelineRecommendedPlugin.active == True) - .order_by(PipelineRecommendedPlugin.position.asc()) + query.order_by(PipelineRecommendedPlugin.position.asc()) .all() )