more typed orm (#28507)

This commit is contained in:
Asuka Minato 2025-11-21 22:45:51 +09:00 committed by GitHub
parent 63b8bbbab3
commit a6c6bcf95c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 196 additions and 134 deletions

View File

@ -163,7 +163,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_type=datasource_type, datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info), datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id, datasource_node_id=start_node_id,
input_data=inputs, input_data=dict(inputs),
pipeline_id=pipeline.id, pipeline_id=pipeline.id,
created_by=user.id, created_by=user.id,
) )

View File

@ -274,6 +274,8 @@ class OpsTraceManager:
raise ValueError("App not found") raise ValueError("App not found")
tenant_id = app.tenant_id tenant_id = app.tenant_id
if trace_config_data.tracing_config is None:
raise ValueError("Tracing config cannot be None.")
decrypt_tracing_config = cls.decrypt_tracing_config( decrypt_tracing_config = cls.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config tenant_id, tracing_provider, trace_config_data.tracing_config
) )

View File

@ -309,11 +309,12 @@ class ProviderManager:
(model for model in available_models if model.model == "gpt-4"), available_models[0] (model for model in available_models if model.model == "gpt-4"), available_models[0]
) )
default_model = TenantDefaultModel() default_model = TenantDefaultModel(
default_model.tenant_id = tenant_id tenant_id=tenant_id,
default_model.model_type = model_type.to_origin_model_type() model_type=model_type.to_origin_model_type(),
default_model.provider_name = available_model.provider.provider provider_name=available_model.provider.provider,
default_model.model_name = available_model.model model_name=available_model.model,
)
db.session.add(default_model) db.session.add(default_model)
db.session.commit() db.session.commit()

View File

@ -11,8 +11,7 @@ from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated from typing_extensions import deprecated
from models.base import TypeBase from .base import TypeBase
from .engine import db from .engine import db
from .types import LongText, StringUUID from .types import LongText, StringUUID

View File

@ -5,8 +5,9 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_co
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7 from libs.uuid_utils import uuidv7
from models.engine import metadata
from models.types import StringUUID from .engine import metadata
from .types import StringUUID
class Base(DeclarativeBase): class Base(DeclarativeBase):

View File

@ -22,11 +22,10 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField, Metad
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage from extensions.ext_storage import storage
from libs.uuid_utils import uuidv7 from libs.uuid_utils import uuidv7
from models.base import TypeBase
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account from .account import Account
from .base import Base from .base import Base, TypeBase
from .engine import db from .engine import db
from .model import App, Tag, TagBinding, UploadFile from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index
@ -934,21 +933,25 @@ class AppDatasetJoin(TypeBase):
return db.session.get(App, self.app_id) return db.session.get(App, self.app_id)
class DatasetQuery(Base): class DatasetQuery(TypeBase):
__tablename__ = "dataset_queries" __tablename__ = "dataset_queries"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"), sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
sa.Index("dataset_query_dataset_id_idx", "dataset_id"), sa.Index("dataset_query_dataset_id_idx", "dataset_id"),
) )
id = mapped_column(StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(
dataset_id = mapped_column(StringUUID, nullable=False) StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False
content = mapped_column(LongText, nullable=False) )
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False) source: Mapped[str] = mapped_column(String(255), nullable=False)
source_app_id = mapped_column(StringUUID, nullable=True) source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role = mapped_column(String(255), nullable=False) created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
class DatasetKeywordTable(TypeBase): class DatasetKeywordTable(TypeBase):
@ -1047,12 +1050,12 @@ class TidbAuthBinding(Base):
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"), sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"), sa.Index("tidb_auth_bindings_status_idx", "status"),
) )
id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=True) tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
account: Mapped[str] = mapped_column(String(255), nullable=False) account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -1148,7 +1151,7 @@ class ExternalKnowledgeApis(TypeBase):
return dataset_bindings return dataset_bindings
class ExternalKnowledgeBindings(Base): class ExternalKnowledgeBindings(TypeBase):
__tablename__ = "external_knowledge_bindings" __tablename__ = "external_knowledge_bindings"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
@ -1158,16 +1161,18 @@ class ExternalKnowledgeBindings(Base):
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
) )
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False)
tenant_id = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
external_knowledge_api_id = mapped_column(StringUUID, nullable=False) external_knowledge_api_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
external_knowledge_id = mapped_column(String(512), nullable=False) external_knowledge_id: Mapped[str] = mapped_column(String(512), nullable=False)
created_by = mapped_column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
updated_by = mapped_column(StringUUID, nullable=True) DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None, init=False)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
) )
@ -1245,49 +1250,61 @@ class DatasetMetadataBinding(Base):
created_by = mapped_column(StringUUID, nullable=False) created_by = mapped_column(StringUUID, nullable=False)
class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] class PipelineBuiltInTemplate(TypeBase):
__tablename__ = "pipeline_built_in_templates" __tablename__ = "pipeline_built_in_templates"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) __table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
id = mapped_column(StringUUID, default=lambda: str(uuidv7())) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
name = mapped_column(sa.String(255), nullable=False) name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description = mapped_column(LongText, nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure = mapped_column(sa.String(255), nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
icon = mapped_column(sa.JSON, nullable=False) icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
yaml_content = mapped_column(LongText, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False)
copyright = mapped_column(sa.String(255), nullable=False) copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False)
privacy_policy = mapped_column(sa.String(255), nullable=False) privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False)
position = mapped_column(sa.Integer, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
install_count = mapped_column(sa.Integer, nullable=False, default=0) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
language = mapped_column(sa.String(255), nullable=False) language: Mapped[str] = mapped_column(sa.String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
updated_at = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() )
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
) )
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] class PipelineCustomizedTemplate(TypeBase):
__tablename__ = "pipeline_customized_templates" __tablename__ = "pipeline_customized_templates"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"), sa.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"), sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
) )
id = mapped_column(StringUUID, default=lambda: str(uuidv7())) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
tenant_id = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name = mapped_column(sa.String(255), nullable=False) name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description = mapped_column(LongText, nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure = mapped_column(sa.String(255), nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
icon = mapped_column(sa.JSON, nullable=False) icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
position = mapped_column(sa.Integer, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
yaml_content = mapped_column(LongText, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False)
install_count = mapped_column(sa.Integer, nullable=False, default=0) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
language = mapped_column(sa.String(255), nullable=False) language: Mapped[str] = mapped_column(sa.String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True) updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None, init=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
updated_at = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() )
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
) )
@property @property
@ -1320,34 +1337,42 @@ class Pipeline(Base): # type: ignore[name-defined]
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
class DocumentPipelineExecutionLog(Base): class DocumentPipelineExecutionLog(TypeBase):
__tablename__ = "document_pipeline_execution_logs" __tablename__ = "document_pipeline_execution_logs"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"), sa.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"), sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
) )
id = mapped_column(StringUUID, default=lambda: str(uuidv7())) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
pipeline_id = mapped_column(StringUUID, nullable=False) pipeline_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
datasource_type = mapped_column(sa.String(255), nullable=False) datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
datasource_info = mapped_column(LongText, nullable=False) datasource_info: Mapped[str] = mapped_column(LongText, nullable=False)
datasource_node_id = mapped_column(sa.String(255), nullable=False) datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
input_data = mapped_column(sa.JSON, nullable=False) input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
created_by = mapped_column(StringUUID, nullable=True) created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
class PipelineRecommendedPlugin(Base): class PipelineRecommendedPlugin(TypeBase):
__tablename__ = "pipeline_recommended_plugins" __tablename__ = "pipeline_recommended_plugins"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),) __table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
id = mapped_column(StringUUID, default=lambda: str(uuidv7())) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
plugin_id = mapped_column(LongText, nullable=False) plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
provider_name = mapped_column(LongText, nullable=False) provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
position = mapped_column(sa.Integer, nullable=False, default=0) position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
active = mapped_column(sa.Boolean, nullable=False, default=True) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
updated_at = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() )
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
) )

View File

@ -31,7 +31,7 @@ from .provider_ids import GenericProviderID
from .types import LongText, StringUUID from .types import LongText, StringUUID
if TYPE_CHECKING: if TYPE_CHECKING:
from models.workflow import Workflow from .workflow import Workflow
class DifySetup(TypeBase): class DifySetup(TypeBase):
@ -1747,36 +1747,40 @@ class UploadFile(Base):
self.source_url = source_url self.source_url = source_url
class ApiRequest(Base): class ApiRequest(TypeBase):
__tablename__ = "api_requests" __tablename__ = "api_requests"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="api_request_pkey"), sa.PrimaryKeyConstraint("id", name="api_request_pkey"),
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"), sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
) )
id = mapped_column(StringUUID, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
tenant_id = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
api_token_id = mapped_column(StringUUID, nullable=False) api_token_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
path: Mapped[str] = mapped_column(String(255), nullable=False) path: Mapped[str] = mapped_column(String(255), nullable=False)
request = mapped_column(LongText, nullable=True) request: Mapped[str | None] = mapped_column(LongText, nullable=True)
response = mapped_column(LongText, nullable=True) response: Mapped[str | None] = mapped_column(LongText, nullable=True)
ip: Mapped[str] = mapped_column(String(255), nullable=False) ip: Mapped[str] = mapped_column(String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
class MessageChain(Base): class MessageChain(TypeBase):
__tablename__ = "message_chains" __tablename__ = "message_chains"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_chain_pkey"), sa.PrimaryKeyConstraint("id", name="message_chain_pkey"),
sa.Index("message_chain_message_id_idx", "message_id"), sa.Index("message_chain_message_id_idx", "message_id"),
) )
id = mapped_column(StringUUID, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
message_id = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False)
input = mapped_column(LongText, nullable=True) input: Mapped[str | None] = mapped_column(LongText, nullable=True)
output = mapped_column(LongText, nullable=True) output: Mapped[str | None] = mapped_column(LongText, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
)
class MessageAgentThought(Base): class MessageAgentThought(Base):
@ -1956,22 +1960,28 @@ class TagBinding(TypeBase):
) )
class TraceAppConfig(Base): class TraceAppConfig(TypeBase):
__tablename__ = "trace_app_config" __tablename__ = "trace_app_config"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
sa.Index("trace_app_config_app_id_idx", "app_id"), sa.Index("trace_app_config_app_id_idx", "app_id"),
) )
id = mapped_column(StringUUID, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
app_id = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tracing_provider = mapped_column(String(255), nullable=True) tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
tracing_config = mapped_column(sa.JSON, nullable=True) tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
updated_at = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
) )
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
@property @property
def tracing_config_dict(self) -> dict[str, Any]: def tracing_config_dict(self) -> dict[str, Any]:

View File

@ -166,21 +166,23 @@ class ProviderModel(TypeBase):
return credential.encrypted_config if credential else None return credential.encrypted_config if credential else None
class TenantDefaultModel(Base): class TenantDefaultModel(TypeBase):
__tablename__ = "tenant_default_models" __tablename__ = "tenant_default_models"
__table_args__ = ( __table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
) )
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
) )

View File

@ -6,8 +6,7 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from models.base import TypeBase from .base import TypeBase
from .types import AdjustedJSON, LongText, StringUUID, adjusted_json_index from .types import AdjustedJSON, LongText, StringUUID, adjusted_json_index

View File

@ -6,8 +6,8 @@ from sqlalchemy import DateTime, String
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.base import TypeBase
from .base import TypeBase
from .types import BinaryData, LongText from .types import BinaryData, LongText

View File

@ -12,8 +12,8 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from models.base import TypeBase
from .base import TypeBase
from .engine import db from .engine import db
from .model import Account, App, Tenant from .model import Account, App, Tenant
from .types import LongText, StringUUID from .types import LongText, StringUUID

View File

@ -237,7 +237,7 @@ class WorkflowTriggerLog(Base):
@property @property
def created_by_end_user(self): def created_by_end_user(self):
from models.model import EndUser from .model import EndUser
created_by_role = CreatorUserRole(self.created_by_role) created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None

View File

@ -5,8 +5,7 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from models.base import TypeBase from .base import TypeBase
from .engine import db from .engine import db
from .model import Message from .model import Message
from .types import StringUUID from .types import StringUUID

View File

@ -86,7 +86,7 @@ class WorkflowType(StrEnum):
:param app_mode: app mode :param app_mode: app mode
:return: workflow type :return: workflow type
""" """
from models.model import AppMode from .model import AppMode
app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode)
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
@ -413,7 +413,7 @@ class Workflow(Base):
For accurate checking, use a direct query with tenant_id, app_id, and version. For accurate checking, use a direct query with tenant_id, app_id, and version.
""" """
from models.tools import WorkflowToolProvider from .tools import WorkflowToolProvider
stmt = select( stmt = select(
exists().where( exists().where(
@ -634,7 +634,7 @@ class WorkflowRun(Base):
@property @property
def created_by_end_user(self): def created_by_end_user(self):
from models.model import EndUser from .model import EndUser
created_by_role = CreatorUserRole(self.created_by_role) created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
@ -653,7 +653,7 @@ class WorkflowRun(Base):
@property @property
def message(self): def message(self):
from models.model import Message from .model import Message
return ( return (
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
@ -874,7 +874,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@property @property
def created_by_end_user(self): def created_by_end_user(self):
from models.model import EndUser from .model import EndUser
created_by_role = CreatorUserRole(self.created_by_role) created_by_role = CreatorUserRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here. # TODO(-LAN-): Avoid using db.session.get() here.
@ -1130,7 +1130,7 @@ class WorkflowAppLog(TypeBase):
@property @property
def created_by_end_user(self): def created_by_end_user(self):
from models.model import EndUser from .model import EndUser
created_by_role = CreatorUserRole(self.created_by_role) created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None

View File

@ -254,6 +254,8 @@ class DatasetService:
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id) external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
if not external_knowledge_api: if not external_knowledge_api:
raise ValueError("External API template not found.") raise ValueError("External API template not found.")
if external_knowledge_id is None:
raise ValueError("external_knowledge_id is required")
external_knowledge_binding = ExternalKnowledgeBindings( external_knowledge_binding = ExternalKnowledgeBindings(
tenant_id=tenant_id, tenant_id=tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,

View File

@ -257,12 +257,16 @@ class ExternalDatasetService:
db.session.add(dataset) db.session.add(dataset)
db.session.flush() db.session.flush()
if args.get("external_knowledge_id") is None:
raise ValueError("external_knowledge_id is required")
if args.get("external_knowledge_api_id") is None:
raise ValueError("external_knowledge_api_id is required")
external_knowledge_binding = ExternalKnowledgeBindings( external_knowledge_binding = ExternalKnowledgeBindings(
tenant_id=tenant_id, tenant_id=tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,
external_knowledge_api_id=args.get("external_knowledge_api_id"), external_knowledge_api_id=args.get("external_knowledge_api_id") or "",
external_knowledge_id=args.get("external_knowledge_id"), external_knowledge_id=args.get("external_knowledge_id") or "",
created_by=user_id, created_by=user_id,
) )
db.session.add(external_knowledge_binding) db.session.add(external_knowledge_binding)

View File

@ -82,7 +82,12 @@ class HitTestingService:
logger.debug("Hit testing retrieve in %s seconds", end - start) logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery( dataset_query = DatasetQuery(
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
) )
db.session.add(dataset_query) db.session.add(dataset_query)
@ -118,7 +123,12 @@ class HitTestingService:
logger.debug("External knowledge hit testing retrieve in %s seconds", end - start) logger.debug("External knowledge hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery( dataset_query = DatasetQuery(
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
) )
db.session.add(dataset_query) db.session.add(dataset_query)

View File

@ -29,6 +29,8 @@ class OpsService:
if not app: if not app:
return None return None
tenant_id = app.tenant_id tenant_id = app.tenant_id
if trace_config_data.tracing_config is None:
raise ValueError("Tracing config cannot be None.")
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config tenant_id, tracing_provider, trace_config_data.tracing_config
) )

View File

@ -1119,13 +1119,19 @@ class RagPipelineService:
with Session(db.engine) as session: with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session) rag_pipeline_dsl_service = RagPipelineDslService(session)
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
if args.get("icon_info") is None:
args["icon_info"] = {}
if args.get("description") is None:
raise ValueError("Description is required")
if args.get("name") is None:
raise ValueError("Name is required")
pipeline_customized_template = PipelineCustomizedTemplate( pipeline_customized_template = PipelineCustomizedTemplate(
name=args.get("name"), name=args.get("name") or "",
description=args.get("description"), description=args.get("description") or "",
icon=args.get("icon_info"), icon=args.get("icon_info") or {},
tenant_id=pipeline.tenant_id, tenant_id=pipeline.tenant_id,
yaml_content=dsl, yaml_content=dsl,
install_count=0,
position=max_position + 1 if max_position else 1, position=max_position + 1 if max_position else 1,
chunk_structure=dataset.chunk_structure, chunk_structure=dataset.chunk_structure,
language="en-US", language="en-US",

View File

@ -322,9 +322,9 @@ class RagPipelineTransformService:
datasource_info=data_source_info, datasource_info=data_source_info,
input_data={}, input_data={},
created_by=document.created_by, created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=file_node_id, datasource_node_id=file_node_id,
) )
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document) db.session.add(document)
db.session.add(document_pipeline_execution_log) db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "notion_import": elif document.data_source_type == "notion_import":
@ -350,9 +350,9 @@ class RagPipelineTransformService:
datasource_info=data_source_info, datasource_info=data_source_info,
input_data={}, input_data={},
created_by=document.created_by, created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=notion_node_id, datasource_node_id=notion_node_id,
) )
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document) db.session.add(document)
db.session.add(document_pipeline_execution_log) db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "website_crawl": elif document.data_source_type == "website_crawl":
@ -379,8 +379,8 @@ class RagPipelineTransformService:
datasource_info=data_source_info, datasource_info=data_source_info,
input_data={}, input_data={},
created_by=document.created_by, created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=datasource_node_id, datasource_node_id=datasource_node_id,
) )
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document) db.session.add(document)
db.session.add(document_pipeline_execution_log) db.session.add(document_pipeline_execution_log)