part of add type to orm (#26262)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-10-11 00:40:54 +09:00 committed by GitHub
parent fdb53fdeb1
commit 3922ad876f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 94 additions and 93 deletions

View File

@ -30,6 +30,8 @@ jobs:
run: | run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types) # Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF' cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union id: convert-optional-to-union

View File

@ -61,18 +61,18 @@ class Dataset(Base):
created_by = mapped_column(StringUUID, nullable=False) created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True) updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = mapped_column(db.String(255), nullable=True) embedding_model = mapped_column(db.String(255), nullable=True)
embedding_model_provider = mapped_column(db.String(255), nullable=True) embedding_model_provider = mapped_column(db.String(255), nullable=True)
keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10")) keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True) collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(JSONB, nullable=True) retrieval_model = mapped_column(JSONB, nullable=True)
built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
icon_info = db.Column(JSONB, nullable=True) icon_info = mapped_column(JSONB, nullable=True)
runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying")) runtime_mode = mapped_column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
pipeline_id = db.Column(StringUUID, nullable=True) pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = db.Column(db.String(255), nullable=True) chunk_structure = mapped_column(db.String(255), nullable=True)
enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true"))
@property @property
def total_documents(self): def total_documents(self):
@ -1226,21 +1226,21 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_built_in_templates" __tablename__ = "pipeline_built_in_templates"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
name = db.Column(db.String(255), nullable=False) name = mapped_column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False) description = mapped_column(sa.Text, nullable=False)
chunk_structure = db.Column(db.String(255), nullable=False) chunk_structure = mapped_column(db.String(255), nullable=False)
icon = db.Column(db.JSON, nullable=False) icon = mapped_column(sa.JSON, nullable=False)
yaml_content = db.Column(db.Text, nullable=False) yaml_content = mapped_column(sa.Text, nullable=False)
copyright = db.Column(db.String(255), nullable=False) copyright = mapped_column(db.String(255), nullable=False)
privacy_policy = db.Column(db.String(255), nullable=False) privacy_policy = mapped_column(db.String(255), nullable=False)
position = db.Column(db.Integer, nullable=False) position = mapped_column(sa.Integer, nullable=False)
install_count = db.Column(db.Integer, nullable=False, default=0) install_count = mapped_column(sa.Integer, nullable=False, default=0)
language = db.Column(db.String(255), nullable=False) language = mapped_column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = db.Column(StringUUID, nullable=False) created_by = mapped_column(StringUUID, nullable=False)
updated_by = db.Column(StringUUID, nullable=True) updated_by = mapped_column(StringUUID, nullable=True)
@property @property
def created_user_name(self): def created_user_name(self):
@ -1257,20 +1257,20 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
db.Index("pipeline_customized_template_tenant_idx", "tenant_id"), db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
) )
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False) name = mapped_column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False) description = mapped_column(sa.Text, nullable=False)
chunk_structure = db.Column(db.String(255), nullable=False) chunk_structure = mapped_column(db.String(255), nullable=False)
icon = db.Column(db.JSON, nullable=False) icon = mapped_column(sa.JSON, nullable=False)
position = db.Column(db.Integer, nullable=False) position = mapped_column(sa.Integer, nullable=False)
yaml_content = db.Column(db.Text, nullable=False) yaml_content = mapped_column(sa.Text, nullable=False)
install_count = db.Column(db.Integer, nullable=False, default=0) install_count = mapped_column(sa.Integer, nullable=False, default=0)
language = db.Column(db.String(255), nullable=False) language = mapped_column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False) created_by = mapped_column(StringUUID, nullable=False)
updated_by = db.Column(StringUUID, nullable=True) updated_by = mapped_column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property @property
def created_user_name(self): def created_user_name(self):
@ -1284,17 +1284,17 @@ class Pipeline(Base): # type: ignore[name-defined]
__tablename__ = "pipelines" __tablename__ = "pipelines"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),) __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False) name = mapped_column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) description = mapped_column(sa.Text, nullable=False, server_default=db.text("''::character varying"))
workflow_id = db.Column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True)
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
created_by = db.Column(StringUUID, nullable=True) created_by = mapped_column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True) updated_by = mapped_column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
def retrieve_dataset(self, session: Session): def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
@ -1307,25 +1307,25 @@ class DocumentPipelineExecutionLog(Base):
db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"), db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
) )
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
pipeline_id = db.Column(StringUUID, nullable=False) pipeline_id = mapped_column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False)
datasource_type = db.Column(db.String(255), nullable=False) datasource_type = mapped_column(db.String(255), nullable=False)
datasource_info = db.Column(db.Text, nullable=False) datasource_info = mapped_column(sa.Text, nullable=False)
datasource_node_id = db.Column(db.String(255), nullable=False) datasource_node_id = mapped_column(db.String(255), nullable=False)
input_data = db.Column(db.JSON, nullable=False) input_data = mapped_column(sa.JSON, nullable=False)
created_by = db.Column(StringUUID, nullable=True) created_by = mapped_column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class PipelineRecommendedPlugin(Base): class PipelineRecommendedPlugin(Base):
__tablename__ = "pipeline_recommended_plugins" __tablename__ = "pipeline_recommended_plugins"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),) __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
plugin_id = db.Column(db.Text, nullable=False) plugin_id = mapped_column(sa.Text, nullable=False)
provider_name = db.Column(db.Text, nullable=False) provider_name = mapped_column(sa.Text, nullable=False)
position = db.Column(db.Integer, nullable=False, default=0) position = mapped_column(sa.Integer, nullable=False, default=0)
active = db.Column(db.Boolean, nullable=False, default=True) active = mapped_column(sa.Boolean, nullable=False, default=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -1,7 +1,8 @@
from datetime import datetime from datetime import datetime
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped from sqlalchemy.orm import Mapped, mapped_column
from .base import Base from .base import Base
from .engine import db from .engine import db
@ -15,10 +16,10 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"), db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
) )
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False) provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) system_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
class DatasourceProvider(Base): class DatasourceProvider(Base):
@ -28,19 +29,19 @@ class DatasourceProvider(Base):
db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"), db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"), db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
) )
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = db.Column(db.String(255), nullable=False) name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False) provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = mapped_column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) encrypted_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
avatar_url: Mapped[str] = db.Column(db.Text, nullable=True, default="default") avatar_url: Mapped[str] = mapped_column(sa.Text, nullable=True, default="default")
is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1") expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
class DatasourceOauthTenantParamConfig(Base): class DatasourceOauthTenantParamConfig(Base):
@ -50,12 +51,12 @@ class DatasourceOauthTenantParamConfig(Base):
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"), db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
) )
id = db.Column(StringUUID, server_default=db.text("uuidv7()")) id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False) provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
client_params: Mapped[dict] = db.Column(JSONB, nullable=False, default={}) client_params: Mapped[dict] = mapped_column(JSONB, nullable=False, default={})
enabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, default=False) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now) updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)

View File

@ -8,8 +8,6 @@ from sqlalchemy.orm import Mapped, mapped_column
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.base import Base from models.base import Base
from .engine import db
class CeleryTask(Base): class CeleryTask(Base):
"""Task result/status.""" """Task result/status."""
@ -19,7 +17,7 @@ class CeleryTask(Base):
id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
task_id = mapped_column(String(155), unique=True) task_id = mapped_column(String(155), unique=True)
status = mapped_column(String(50), default=states.PENDING) status = mapped_column(String(50), default=states.PENDING)
result = mapped_column(db.PickleType, nullable=True) result = mapped_column(sa.PickleType, nullable=True)
date_done = mapped_column( date_done = mapped_column(
DateTime, DateTime,
default=lambda: naive_utc_now(), default=lambda: naive_utc_now(),
@ -44,5 +42,5 @@ class CeleryTaskSet(Base):
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
) )
taskset_id = mapped_column(String(155), unique=True) taskset_id = mapped_column(String(155), unique=True)
result = mapped_column(db.PickleType, nullable=True) result = mapped_column(sa.PickleType, nullable=True)
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True)